Repository for paper "Non-intrusive speech intelligibility prediction from discrete latent representations"

Overview

Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations

Official repository for the paper "Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations".

This public repository is a work in progress! Results here bear no resemblance to results in the paper!

We predict the intelligibility of binaural speech signals by first extracting latent representations from raw audio. Then, a lightweight predictor over these latent representations can be trained. This results in improved performance over predicting on spectral features of the audio, despite the feature extractor not being explicitly trained for this task. In certain cases, a single layer is sufficient for strong correlations between the predictions and the ground-truth scores.

This repository contains:

  • vqcpc/ - Module for VQCPC model in PyTorch
  • stoi/ - Module for Small and SeqPool predictor models in PyTorch
  • data.py - File containing various PyTorch custom datasets
  • main-vqcpc.py - Script for VQCPC training
  • create-latents.py - Script for generating latent dataset from trained VQCPC
  • plot-latents.py - Script for visualizing extracted latent representations
  • main-stoi.py - Script for STOI predictor training
  • main-test.py - Script for evaluating models
  • compute-correlations.py - Script for computing metrics for many models
  • checkpoints/ - Directory containing trained checkpoints of VQCPC and STOI predictor models
  • config/ - Directory containing various configuration files for experiments
  • results/ - Directory containing official results from experiments
  • dataset/ - Directory containing metadata files for the dataset
  • data-generator/ - Directory containing dataset generation scripts (MATLAB)

All models are implemented in PyTorch. The training scripts are implemented using ptpt - a lightweight framework around PyTorch.

Figure: Visualisation of binaural waveform, predicted per-frame STOI, and latent representation.

Usage

VQ-CPC Training

Begin VQ-CPC training using the configuration defined in config-path.toml:

python main-vqcpc.py --cfg-path config-path.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)
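
As a rough illustration of how the flags listed above fit together, the sketch below declares them with argparse. It is not taken from main-vqcpc.py: the function name build_parser, the argument types, and the help strings are assumptions, and only the flag names and the two documented defaults (8 workers, seed 12345) come from this README.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented training flags."""
    parser = argparse.ArgumentParser(description="VQ-CPC training (illustrative sketch)")
    # Types below are assumptions; only the flag names are documented.
    parser.add_argument("--cfg-path", type=str, default=None, help="path to TOML configuration")
    parser.add_argument("--resume", type=str, default=None, help="resume from specified checkpoint")
    parser.add_argument("--no-save", action="store_true", help="do not save training progress")
    parser.add_argument("--no-cuda", action="store_true", help="do not try to access CUDA device")
    parser.add_argument("--no-amp", action="store_true", help="disable automatic mixed precision")
    parser.add_argument("--nb-workers", type=int, default=8, help="number of data loading workers")
    parser.add_argument("--detect-anomaly", action="store_true", help="detect autograd anomalies")
    parser.add_argument("--seed", type=int, default=12345, help="random seed")
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--cfg-path", "config-path.toml", "--no-amp"])
print(args)
```

Flags that are not passed fall back to their defaults, so the command shown earlier would run with 8 data loading workers and seed 12345 under this sketch.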

Latent Dataset Generation

Generate a latent dataset from the audio dataset wav-dataset using the pre-trained VQCPC checkpoint model-checkpoint.pt, writing the output to latent-dataset with the configuration defined in config.toml:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml

As above, but distributed across n processes with script rank r:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml --array-size n --array-rank r

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--detect-anomaly    # detect autograd anomalies and terminate if encountered
-n                  # alias for `--array-size`
-r                  # alias for `--array-rank`
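
One way to picture the --array-size / --array-rank split is a strided partition of the file list, where each process handles every n-th file starting at its own rank. The sketch below assumes zero-indexed ranks and a strided scheme; create-latents.py may divide the work differently, and shard_for_rank and the file names are hypothetical.

```python
from typing import List, Sequence


def shard_for_rank(items: Sequence[str], array_size: int, array_rank: int) -> List[str]:
    """Return the subset of items handled by one process (assumed strided split)."""
    if array_size < 1:
        raise ValueError("array_size must be at least 1")
    if not 0 <= array_rank < array_size:
        raise ValueError("array_rank must satisfy 0 <= array_rank < array_size")
    # Process r takes items r, r + n, r + 2n, ...
    return list(items[array_rank::array_size])


# Hypothetical file names, split across n = 3 processes.
files = [f"utt-{i:03d}.wav" for i in range(10)]
shards = [shard_for_rank(files, 3, r) for r in range(3)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

Under this scheme every file is assigned to exactly one rank, so the n jobs can run independently (for example as a cluster job array) and write into the same output directory.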

Latent Plotting

Begin interactive VQCPC latent visualisation script using pre-trained model model-checkpoint.pt on dataset wav-dataset using configuration defined in config.toml:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml

If you additionally have a pre-trained, per-frame STOI score predictor (not a SeqPool predictor), you can specify its checkpoint stoi-checkpoint.pt and configuration stoi-config.toml to plot per-frame scores alongside the waveform and latent features:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml --stoi stoi-checkpoint.pt --stoi-cfg stoi-config.toml

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--cmap              # define matplotlib colourmap
--style             # define matplotlib style

STOI Predictor Training

Begin intelligibility score predictor training script using configuration in config.toml:

python main-stoi.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)

Predictor Evaluation

Begin evaluation of a pre-trained STOI score predictor using checkpoint stoi-checkpoint.pt on dataset dataset-root using configuration in stoi-config.toml:

python main-test.py stoi-checkpoint.pt dataset-root --cfg-path stoi-config.toml

Other useful arguments:

--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--batch-size        # control dataloader batch size
--seed              # random seed (default: 12345)

Overall Evaluation

Compare the predictions in multiple results files produced by main-test.py against the dataset ground truth:

python compute-correlations.py ground-truth.csv pred-1.csv ... pred-n.csv --names pred-1 ... pred-n
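
This README does not state which metrics compute-correlations.py reports. As a minimal sketch of the kind of comparison involved, the example below computes the Pearson correlation between ground-truth scores and each named set of predictions. The choice of Pearson correlation, the pearson helper, and the toy scores are all assumptions for illustration, not output from the script.

```python
import math
from typing import Dict, Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("inputs must have equal length of at least 2")
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / math.sqrt(var_x * var_y)


# Toy data standing in for ground-truth.csv and two prediction files.
ground_truth = [0.2, 0.4, 0.6, 0.8]
predictions: Dict[str, Sequence[float]] = {
    "pred-1": [0.25, 0.35, 0.65, 0.75],  # roughly tracks the ground truth
    "pred-2": [0.8, 0.6, 0.4, 0.2],      # reversed ordering
}
scores = {name: pearson(ground_truth, pred) for name, pred in predictions.items()}
for name, score in scores.items():
    print(f"{name}: {score:+.3f}")
```

A value near +1 indicates predictions that rise and fall with the ground truth, while a value near -1 indicates the opposite ordering.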

Configuration

Example configurations for all experiments can be found here.

We use toml files to define configurations. Each one consists of three sections:

  • [trainer]: configuration options for ptpt.TrainerConfig.
  • [data]: configuration options for the dataset.
  • [vqcpc] or [stoi]: configuration options for the VQCPC and predictor models respectively.

Checkpoints

Pretrained checkpoints for all models can be found here

Citation

TODO: add citation once paper published / arXiv-ed :)

Owner
Alex McKinney
Final-year student at Durham University. Interested in generative models and unsupervised representation learning.