Fast, general, and tested differentiable structured prediction in PyTorch

Overview

Torch-Struct: Structured Prediction Library

A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.

Tutorial paper describing methodology.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])

# Compute marginals
show(dist.marginals[0])

# Compute argmax
show(dist.argmax.detach()[0])

# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
# Compute samples 
show(dist.sample((1,)).detach()[0, 0])

# Padding/Masking built into library.
dist = DependencyCRF(vals.log(), lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1,9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max
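
As a rough illustration of what these quantities mean (not how the library computes them), the sketch below enumerates every label sequence of a tiny linear chain in plain Python. The `unary` and `pairwise` tables and the `score` helper are made up for this example.

```python
import itertools
import math

# Toy log-potentials (illustrative values, not taken from the library).
unary = [[0.5, 0.1], [0.2, 0.9], [0.3, 0.3]]   # unary[t][label]
pairwise = [[0.4, -0.2], [-0.1, 0.6]]           # pairwise[prev][cur]

def score(seq):
    """Unnormalized log-score of one label sequence."""
    total = sum(unary[t][y] for t, y in enumerate(seq))
    total += sum(pairwise[a][b] for a, b in zip(seq, seq[1:]))
    return total

structures = list(itertools.product(range(2), repeat=len(unary)))
scores = [score(seq) for seq in structures]

# partition: log of the sum of exponentiated scores over all structures
log_z = math.log(sum(math.exp(s) for s in scores))
# log_prob: normalized log-probability of each structure
log_probs = [s - log_z for s in scores]
# argmax: the highest-scoring structure
best = max(structures, key=score)
# entropy: -sum_y p(y) log p(y)
entropy = -sum(math.exp(lp) * lp for lp in log_probs)
```

Enumeration is exponential in the sequence length, so it is only useful as a reference for checking a dynamic program on small inputs.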

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree structured parameterizations TreeLSTM / SpanLSTM

Low-level API:

Everything implemented through semiring dynamic programming.

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.
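
The idea of semiring dynamic programming can be sketched in plain Python: one chain recursion is written once against abstract `plus` and `times` operations, and swapping the semiring changes what it computes. The tuple layout `(plus, times, zero, one)` and the `chain_dp` helper are assumptions of this sketch, not the library's API.

```python
import math

def logsumexp2(a, b):
    m = max(a, b)
    if m == -math.inf:
        return -math.inf
    return m + math.log(math.exp(a - m) + math.exp(b - m))

# Each semiring is (plus, times, zero, one) over log-space scores.
LOG = (logsumexp2, lambda a, b: a + b, -math.inf, 0.0)  # sums over structures
MAX = (max, lambda a, b: a + b, -math.inf, 0.0)         # best single structure

def chain_dp(edges, semiring):
    """edges[t][i][j] is the log-potential of label i at step t followed by label j."""
    plus, times, zero, one = semiring
    n_labels = len(edges[0])
    alpha = [one] * n_labels
    for step in edges:
        new_alpha = []
        for j in range(n_labels):
            acc = zero
            for i in range(n_labels):
                acc = plus(acc, times(alpha[i], step[i][j]))
            new_alpha.append(acc)
        alpha = new_alpha
    total = zero
    for a in alpha:
        total = plus(total, a)
    return total

edges = [
    [[0.1, 0.7], [0.3, 0.2]],
    [[0.5, 0.0], [0.4, 0.9]],
]
log_partition = chain_dp(edges, LOG)  # log-sum over all label paths
max_score = chain_dp(edges, MAX)      # score of the best label path
```

The same loop yields the log-partition under the log semiring and the best-path score under the max semiring; only the two operations change.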

Examples

Citation

@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

This work was partially supported by NSF grant IIS-1901030.

Comments
  • add tests for CKY

    This PR fixes several bugs in k-best parsing with dist.topk() and adds a simple test for that function.

    I made the changes incrementally so that existing modules that rely on CKY are not affected.

    opened by zhaoyanpeng 8
  • 1st order cky implementation

    Hi,

    I'd like to contribute this implementation of a first-order CKY-style CRF with anchored rule potentials: $\phi[i,j,k,A,B,C] := \phi(A_{i,j} \rightarrow B_{i,k}, C_{k+1,j})$.

    I also added code to the _Struct class that allows calculating marginals even if input tensors don't require a gradient (i.e., after model.eval())

    Please let me know if you'd like to see any changes.

    Thanks, Tom

    opened by teffland 6
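
For readers unfamiliar with anchored rule potentials, here is a minimal inside-algorithm sketch in plain Python. It is not the code from the pull request: `phi` is a hypothetical callable standing in for the six-dimensional potential tensor, spans use inclusive indices, and width-1 spans are assumed to score 0 for every label.

```python
import math

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def inside(phi, n, num_labels):
    """Log-partition over labeled binary trees spanning words 0..n-1.

    phi(i, j, k, A, B, C) scores the rule A[i..j] -> B[i..k] C[k+1..j].
    """
    beta = {}
    # Assumption of this sketch: every label is free on single-word spans.
    for i in range(n):
        for a in range(num_labels):
            beta[i, i, a] = 0.0
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            j = i + width - 1
            for a in range(num_labels):
                terms = [
                    phi(i, j, k, a, b, c) + beta[i, k, b] + beta[k + 1, j, c]
                    for k in range(i, j)          # split point
                    for b in range(num_labels)    # left child label
                    for c in range(num_labels)    # right child label
                ]
                beta[i, j, a] = logsumexp(terms)
    return logsumexp([beta[0, n - 1, a] for a in range(num_labels)])

# With all-zero potentials the partition simply counts labeled trees.
log_z = inside(lambda i, j, k, a, b, c: 0.0, n=3, num_labels=2)
```

With three words and two labels there are 2 tree shapes and 2^5 labelings of the 5 nodes, so `exp(log_z)` is 64 in this toy case.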
  • Mini-batch setting with Semi Markov CRF

    I encounter learning instability when using a batch size > 1 with the semi-Markov CRF (the loss goes to a very large negative number), even when explicitly providing "lengths". I think the bug comes from the masking. The model trains well with a batch size of 1.

    opened by urchade 5
  • Release on PyPI?

    Is there any interest in releasing pytorch-struct (and genbmm) on the official Python Package Index?

    I ran into this because I distribute my constituency parser on PyPI, and I just recently pushed a new version that depends on pytorch-struct: https://pypi.org/project/benepar/0.2.0a0/

    It turns out that packages on PyPI aren't allowed to depend on packages only hosted on github, so users of my parser can't just pip install benepar and have it work right away.

    opened by nikitakit 5
  • up sweep and down sweep

    I'm interested in the parallel scan algorithm for the linear-chain CRF.

    I read the related paper in the tutorial and found that there are two steps, up sweep and down sweep, to obtain the all-prefix-sum.

    I think that in this case the algorithm is used to obtain Z(x) for all the different lengths in a batch, but I couldn't find the down-sweep code in the repo. Can you point me to it?

    opened by allanj 5
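
As background for this question, the reduction (up-sweep) phase alone can be sketched in plain Python. Because the log-semiring matrix product is associative, neighbouring potentials can be combined pairwise in a balanced tree, which is enough when only the total Z(x) is needed; a down sweep is required only if every prefix is wanted. This is an illustration with made-up helper names and does not claim to mirror the repository's code.

```python
import math

def logsumexp(values):
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))

def log_matmul(x, y):
    """Log-semiring matrix product: out[i][j] = logsumexp_k (x[i][k] + y[k][j])."""
    return [
        [logsumexp([x[i][k] + y[k][j] for k in range(len(y))]) for j in range(len(y[0]))]
        for i in range(len(x))
    ]

def sequential(mats):
    """Left-to-right scan: N - 1 dependent steps."""
    out = mats[0]
    for m in mats[1:]:
        out = log_matmul(out, m)
    return out

def tree_reduce(mats):
    """Pairwise reduction: about log2(N) rounds, each round parallelizable."""
    level = list(mats)
    while len(level) > 1:
        nxt = [log_matmul(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            nxt.append(level[-1])  # carry the unpaired last element forward
        level = nxt
    return level[0]

# Five toy 2x2 edge-potential matrices (arbitrary values).
mats = [[[0.1 * t + 0.3 * i - 0.2 * j for j in range(2)] for i in range(2)] for t in range(5)]
log_z_seq = logsumexp([v for row in sequential(mats) for v in row])
log_z_tree = logsumexp([v for row in tree_reduce(mats) for v in row])
```

Both orders of evaluation give the same log-partition up to floating-point error; the tree version simply exposes more parallelism per round.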
  • [Bug] Implementation of Eisner's algorithm does not restrict the root number to 1

    Hey, I found that your implementation of Eisner's algorithm allows an arbitrary number of roots, which is a severe bug since a dependency parse usually has exactly one root token.

    In your DepTree.dp() method, you convert the input so that the root token is the first token in the sentence. Suppose the root x_{0} attaches to word x_{i}: I_{0,0} + C_{1,i} = I_{0,i} and I_{0,i} + C_{i,j} = C_{0,j} for some j < L, where L is the length of the sentence. The complete span C_{0,j} can still attach a new word x_{k} for j < k <= L, which makes multiple root attachments possible.

    I made some changes to your code to restrict the number of roots to 1.

    `
    def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
        if arc_scores_in.dim() not in (3, 4):
            raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")

        labeled = arc_scores_in.dim() == 4
        semiring = self.semiring
        # arc_scores_in = _convert(arc_scores_in)
        arc_scores_in, batch, N, lengths = self._check_potentials(
            arc_scores_in, lengths
        )
        arc_scores_in.requires_grad_(True)
        arc_scores = semiring.sum(arc_scores_in) if labeled else arc_scores_in
        alpha = [
            [
                [
                    Chart((batch, N, N), arc_scores, semiring, cache=cache)
                    for _ in range(2)
                ]
                for _ in range(2)
            ]
            for _ in range(2)
        ]
    
        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)
    
    
        for k in range(1, N):
            f = torch.arange(N - k), torch.arange(k, N)
            ACL = alpha[A][C][L][: N - k, :k]
            ACR = alpha[A][C][R][: N - k, :k]
            BCL = alpha[B][C][L][k:, N - k :]
            BCR = alpha[B][C][R][k:, N - k :]
            x = semiring.dot(ACR, BCL)
            arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])
            alpha[A][I][L][: N - k, k] = arcs_l
            alpha[B][I][L][k:N, N - k - 1] = arcs_l
            arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
            alpha[A][I][R][:N - k, k] = arcs_r
            alpha[B][I][R][k:N, N - k - 1] = arcs_r
            AIR = alpha[A][I][R][: N - k, 1 : k + 1]
            BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]
            new = semiring.dot(ACL, BIL)
            alpha[A][C][L][: N - k, k] = new
            alpha[B][C][L][k:N, N - k - 1] = new
            new = semiring.dot(AIR, BCR)
            alpha[A][C][R][: N - k, k] = new
            alpha[B][C][R][k:N, N - k - 1] = new
    
        root_incomplete_span = semiring.times(alpha[A][C][L][0, :], arc_scores[:, :, torch.arange(N), torch.arange(N)])
        root =  [ Chart((batch,), arc_scores, semiring, cache=cache) for _ in range(N)]
        for k in range(N):
            AIR = root_incomplete_span[:, :, :k+1]
            BCR = alpha[B][C][R][k, N - (k+1):]
            root[k] = semiring.dot(AIR, BCR)
        v = torch.stack([root[l-1][:,i] for i, l in enumerate(lengths)], dim=1)
        return v, [arc_scores_in], alpha
    

    `

    Basically, I no longer treat the first token as the root. I handle the root token right after the for-loop, so you may need to adjust the length variable (length = length - 1, since the root is no longer treated as part of the sentence). I tested the modified code and found no bugs.

    opened by sustcsonglin 4
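
The difference between the two structure sets discussed in this issue can be checked by brute force on very short sentences. The sketch below is independent of the library; `is_tree`, `is_projective`, and `count_trees` are hypothetical helpers that enumerate head assignments, with 0 standing for the artificial root.

```python
import itertools

def is_tree(heads):
    """heads[m - 1] is the head of word m (1-based); 0 is the artificial root."""
    for m in range(1, len(heads) + 1):
        seen = set()
        node = m
        while node != 0:
            if node in seen:
                return False  # a cycle that never reaches the root
            seen.add(node)
            node = heads[node - 1]
    return True

def is_projective(heads):
    arcs = [(min(h, m), max(h, m)) for m, h in enumerate(heads, start=1)]
    for (a, b), (c, d) in itertools.combinations(arcs, 2):
        if a < c < b < d or c < a < d < b:
            return False  # two arcs cross
    return True

def count_trees(n, single_root):
    count = 0
    for heads in itertools.product(range(n + 1), repeat=n):
        if any(h == m for m, h in enumerate(heads, start=1)):
            continue  # a word cannot be its own head
        if single_root and sum(h == 0 for h in heads) != 1:
            continue  # require exactly one word attached to the root
        if is_tree(heads) and is_projective(heads):
            count += 1
    return count

multi_root = count_trees(2, single_root=False)
one_root = count_trees(2, single_root=True)
```

For a two-word sentence there are 3 projective trees when several root attachments are allowed and 2 when exactly one is required, so a partition function can differ between the two conventions even on tiny inputs.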
  • Inference for the HMM model

    Hello! I was playing with the HMM distribution and I obtained some results that I don't really understand. More precisely, I've set the following parameters

    t = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()
    e = torch.tensor([[0.50, 0.50], [0.50, 0.50]]).log()
    i = torch.tensor(np.array([0.99, 0.01])).log()
    x = torch.randint(0, 2, size=(1, 8))
    

    and I was expecting the model to stay in the hidden state 0 regardless of the observed data x – it starts in state 0 and the transition matrix makes it very likely to maintain it. But when plotting the argmax, it appears that the model jumps from one state to the other:

    def show_chain(chain):
        plt.imshow(chain.detach().sum(-1).transpose(0, 1))
    
    dist = torch_struct.HMM(t, e, i, x)
    show_chain(dist.argmax[0])
    

    I must be missing something obvious; but shouldn't dist.argmax correspond to argmax_z p(z | x, Θ)? Thank you!

    opened by danoneata 4
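
As a reference point for the expectation stated in this issue, here is a plain-Python Viterbi sketch using the same parameter values. It does not use or explain the library's HMM class; the index conventions (`log_trans[i][j]` for moving from state i to state j, `log_emit[z][x]` for emitting x from state z) are assumptions of the sketch and may differ from the library's.

```python
import math

def viterbi(log_init, log_trans, log_emit, obs):
    """Return the most likely hidden state path for the observations."""
    n_states = len(log_init)
    delta = [log_init[z] + log_emit[z][obs[0]] for z in range(n_states)]
    back = []
    for x in obs[1:]:
        prev = delta
        pointers, delta = [], []
        for j in range(n_states):
            # Best predecessor of state j at this step.
            best_i = max(range(n_states), key=lambda i: prev[i] + log_trans[i][j])
            pointers.append(best_i)
            delta.append(prev[best_i] + log_trans[best_i][j] + log_emit[j][x])
        back.append(pointers)
    # Follow back-pointers from the best final state.
    state = max(range(n_states), key=lambda z: delta[z])
    path = [state]
    for pointers in reversed(back):
        state = pointers[state]
        path.append(state)
    return path[::-1]

log = math.log
trans = [[log(0.99), log(0.01)], [log(0.01), log(0.99)]]
emit = [[log(0.50), log(0.50)], [log(0.50), log(0.50)]]
init = [log(0.99), log(0.01)]
obs = [0, 1, 1, 0, 1, 0, 0, 1]  # emissions are uniform, so the values do not matter

path = viterbi(init, trans, emit, obs)
```

Under these conventions the path stays in state 0 at every step, which matches the intuition in the question; a different result from another implementation would point to a difference in how the parameters are indexed or interpreted.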
  • DependencyCRF partition function broken

    Getting the following in-place operation error when using the DependencyCRF:

    B,N = 3,50
    phi = torch.randn(B,N,N)
    DependencyCRF(phi).partition
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/deptree.py in _check_potentials(self, arc_scores, lengths)
        121         arc_scores = semiring.convert(arc_scores)
        122         for b in range(batch):
    --> 123             semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
        124             semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])
        125 
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/semirings/semirings.py in zero_(xs)
        124     @staticmethod
        125     def zero_(xs):
    --> 126         return xs.fill_(-1e5)
        127 
        128     @staticmethod
    
    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
    
    opened by teffland 3
  • [Question] How to compute a marginal probability over a (contiguous) set of nodes?

    Hi.

    Thank you for the great library. I have one question that I hope you could help with.

    How can I compute a marginal probability over a (contiguous) set of nodes? Right now, I am using your LinearChain-CRF to do NER. In addition to the best sequence itself, I also need to compute the model’s confidence in its predicted labeling over a segment of the input. For example, what is the probability that a span of tokens constitutes a person name?

    I read your example and see how you get the marginal prob for each individual node. But I was not quite sure how to compute the marginal prob over a subset of nodes. If you could give any hint, it would be great.

    Thank you.

    opened by kimdev95 3
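
One standard way to answer this kind of question, sketched here in plain Python rather than with the library's API, is to run the forward algorithm twice: once with the labels inside the span clamped, and once unconstrained. The ratio of the two partition functions is the span probability. The helpers `log_partition` and `span_log_prob` and the toy potentials are made up for this illustration.

```python
import math

def logsumexp(values):
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))

def log_partition(unary, trans, allowed=None):
    """Forward algorithm; allowed[t], if present, is the set of labels permitted at t."""
    allowed = allowed or {}
    n_labels = len(trans)

    def mask(t, y, value):
        return value if t not in allowed or y in allowed[t] else -math.inf

    alpha = [mask(0, y, unary[0][y]) for y in range(n_labels)]
    for t in range(1, len(unary)):
        alpha = [
            mask(t, y, unary[t][y] + logsumexp([alpha[p] + trans[p][y] for p in range(n_labels)]))
            for y in range(n_labels)
        ]
    return logsumexp(alpha)

def span_log_prob(unary, trans, start, labels):
    """log P(y[start : start + len(labels)] == labels) under the chain CRF."""
    allowed = {start + k: {lab} for k, lab in enumerate(labels)}
    return log_partition(unary, trans, allowed) - log_partition(unary, trans)

# Toy log-potentials: 4 positions, 3 labels (say 0 = O, 1 = B-PER, 2 = I-PER).
unary = [[0.2, 1.0, -0.5], [0.1, 0.8, 0.3], [-0.3, 0.2, 1.1], [0.6, 0.0, -0.2]]
trans = [[0.5, 0.2, -1.0], [-0.2, 0.1, 0.9], [0.3, -0.4, 0.6]]

# Probability that positions 1 and 2 are labeled B-PER, I-PER.
span_lp = span_log_prob(unary, trans, start=1, labels=[1, 2])
```

This gives the probability of one exact labeling of the span; to require that the entity also ends there, the position after the span can be restricted in the same `allowed` dictionary.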
  • Get the score of dist.topk()

    The topk() function returns the top k predictions from the distribution. Is there an easy way to get the corresponding score of each prediction?

    Also, when the sentence is short and k is large, how can I tell how many of the predictions are valid? For example, with DependencyCRF, when the sentence length is 2 and k is 5, I think only the top 3 predictions are valid.

    opened by wangxinyu0922 3
  • Labeled projective dependency CRF

    This is work in progress and isn't ready to merge yet.

    This seems to work for partition, but argmax and marginals don't return what I expect. Both return a tensor of shape (B, N, N); I'd expect them to return (B, N, N, L) tensors instead. Any advice?

    opened by kmkurn 3
  • [Question] How to apply pytorch-struct for 2 dimensional data?

    I could find examples of pytorch-struct usage for 1D sequence data such as text or video frames, but I'm trying to parse table structure in PDF documents.

    Could you provide some hints where to start?

    opened by YuriyPryyma 4
  • end_class support for Autoregressive

    end_class is not used for the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49

    opened by urchade 1
  • Update examples to use newer torchtext APIs

    opened by erip 2
  • Instable learning with SemiMarkov CRF

    Hi,

    First, thank you for fixing #110 (@da03). The SemiCRF works better now, and I was able to get good results on span extraction tasks. However, I still encounter a learning instability where the loss (negative log-prob) becomes negative after several steps (and the accuracy starts to drop). The same problem occurs with batch_size = 1. Below I put the learning curve (f1_score and log loss).

    (Maybe the bug comes from the masking of spans (length, length + span_width) where length + span_width > length, but I am not sure.)

    Edit: I created a test and it seems that the masking is fine. Maybe the problem is in the log_prob computation or the to_parts function?

    train_loss score

    opened by urchade 0
  • fix bug- missing assignment of spans from sentCFG in documentation

    I noticed a small bug in the documentation example for SentCFG. dist.argmax returns (terms, rules, init, spans), but the example in the documentation only assigns (term, rules, init), which gives a dim mismatch, so running the example fails. This fix resolves the issue.

    opened by jdegange 0
Releases(v0.5)