PyTorch Implementation of Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation

Overview

StyleSpeech - PyTorch Implementation

Status (2021.06.13)

  • StyleSpeech (naive branch)
  • Meta-StyleSpeech (main branch)

Quickstart

Dependencies

You can install the Python dependencies with

pip3 install -r requirements.txt

Inference

Download the pretrained models and put them in output/ckpt/LibriTTS/.

For English single-speaker TTS, run

python3 synthesize.py --text "YOUR_DESIRED_TEXT" --ref_audio path/to/reference_audio.wav --restore_step 200000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml

The generated utterances will be put in output/result/. Your synthesized speech will have ref_audio's style.

Batch Inference

Batch inference is also supported. Try

python3 synthesize.py --source preprocessed_data/LibriTTS/val.txt --restore_step 200000 --mode batch -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml

to synthesize all utterances in preprocessed_data/LibriTTS/val.txt. This can be viewed as reconstructing the validation set, with each utterance serving as its own style reference.

Controllability

The pitch/volume/speaking rate of the synthesized utterances can be controlled by specifying the desired pitch/energy/duration ratios. For example, one can shorten the durations by 20% (which speeds up the speech) and decrease the volume by 20% with

python3 synthesize.py --text "YOUR_DESIRED_TEXT" --restore_step 200000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml --duration_control 0.8 --energy_control 0.8

Note that this controllability originates from FastSpeech2 and is not a core focus of StyleSpeech.
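To make the effect of a duration ratio concrete, here is a minimal sketch of FastSpeech2-style length regulation. It assumes per-phoneme durations (in frames) are multiplied by the ratio and rounded before each phoneme is repeated; the function name regulate_length and the toy inputs are hypothetical and are not taken from this repository.

```python
from typing import List, Sequence


def regulate_length(
    phonemes: Sequence[str],
    durations: Sequence[float],
    duration_control: float = 1.0,
) -> List[str]:
    """Repeat each phoneme for its scaled, rounded duration in frames."""
    if len(phonemes) != len(durations):
        raise ValueError("phonemes and durations must have the same length")
    expanded: List[str] = []
    for phoneme, duration in zip(phonemes, durations):
        # Scale the duration by the control ratio, then clamp to a
        # non-negative whole number of frames.
        frames = max(int(round(duration * duration_control)), 0)
        expanded.extend([phoneme] * frames)
    return expanded


# Toy example (hypothetical values): four phonemes, 40 frames in total.
phonemes = ["HH", "AH0", "L", "OW1"]
durations = [5, 10, 5, 20]
normal = regulate_length(phonemes, durations)
faster = regulate_length(phonemes, durations, duration_control=0.8)
print(len(normal), len(faster))
```

With a ratio of 0.8 the utterance uses 20% fewer frames, so the same text is spoken in less time.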

Training

Datasets

The supported datasets are

  • LibriTTS: a multi-speaker English dataset containing 585 hours of speech by 2456 speakers.
  • (more will be added)

Preprocessing

First, run

python3 prepare_align.py config/LibriTTS/preprocess.yaml

for some preparations.

In this implementation, Montreal Forced Aligner (MFA) is used to obtain the alignments between the utterances and the phoneme sequences.

Download the official MFA package and run

./montreal-forced-aligner/bin/mfa_align raw_data/LibriTTS/ lexicon/librispeech-lexicon.txt english preprocessed_data/LibriTTS

or

./montreal-forced-aligner/bin/mfa_train_and_align raw_data/LibriTTS/ lexicon/librispeech-lexicon.txt preprocessed_data/LibriTTS

to align the corpus and then run the preprocessing script.

python3 preprocess.py config/LibriTTS/preprocess.yaml

Training

Train your model with

python3 train.py -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml

As described in the paper, the script first pre-trains the naive model for meta_learning_warmup steps and then meta-trains the model for additional steps via episodic training.
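The two-phase schedule can be pictured with a small sketch. It assumes a 1-indexed global step counter and that the warm-up threshold itself still belongs to pre-training; the function training_phase and the phase labels are hypothetical, and the real train.py may draw the boundary differently.

```python
def training_phase(step: int, meta_learning_warmup: int) -> str:
    """Return the phase a 1-indexed global training step belongs to."""
    if step < 1:
        raise ValueError("step must be >= 1")
    # Assumption: steps up to and including the warm-up threshold
    # train the naive model; later steps use episodic meta-training.
    if step <= meta_learning_warmup:
        return "pretrain"
    return "meta"


# Toy schedule with a warm-up of 3 steps (hypothetical value).
phases = [training_phase(step, meta_learning_warmup=3) for step in range(1, 6)]
print(phases)
```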

TensorBoard

Use

tensorboard --logdir output/log/LibriTTS

to serve TensorBoard on your localhost.

Implementation Issues

  1. Use 22050Hz sampling rate instead of 16kHz.
  2. Add one fully connected layer at the beginning of Mel-Style Encoder to upsample input mel-spectrogram from 80 to 128.
  3. The model size including meta-learner is 28.197M.
  4. Use a maximum batch size of 16 for training instead of 48 or 20, mainly due to the limited memory of a single 24GiB TITAN-RTX. This can be achieved by running the following script, which filters out data longer than max_seq_len:
    python3 filelist_filtering.py -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml
    
    This will generate train_filtered.txt in the same directory as train.txt.
  5. Since the total batch size is decreased, the number of training steps is doubled compared to the original paper.
  6. Use HiFi-GAN instead of MelGAN for vocoding.
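
The length filtering mentioned in item 4 can be sketched as follows. This is only an illustration of the idea, not the logic of filelist_filtering.py: the function filter_by_length, the utterance names, and the frame counts are all hypothetical, and it assumes the mel-spectrogram length of each utterance is already known.

```python
from typing import Dict, List


def filter_by_length(
    names: List[str],
    mel_lengths: Dict[str, int],
    max_seq_len: int,
) -> List[str]:
    """Keep utterances whose mel-spectrogram length does not exceed max_seq_len."""
    return [name for name in names if mel_lengths[name] <= max_seq_len]


# Hypothetical utterances with their mel lengths in frames.
names = ["utt_a", "utt_b", "utt_c"]
mel_lengths = {"utt_a": 800, "utt_b": 1000, "utt_c": 1200}
kept = filter_by_length(names, mel_lengths, max_seq_len=1000)
print(kept)
```

Dropping the longest utterances bounds the padded sequence length per batch, which is what keeps memory use predictable at a fixed batch size.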

Citation

@misc{lee2021stylespeech,
  author = {Lee, Keon},
  title = {StyleSpeech},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/keonlee9420/StyleSpeech}}
}

Comments
  • What is the performance compared with AdaSpeech?

    Thank you for your great work and for sharing it. Your work looks different from AdaSpeech and NAUTILUS. You use GANs, which I have not seen in other papers on adaptive TTS. Have you compared this method with AdaSpeech 1/2? How about the MOS and similarity?

    opened by Liujingxiu23 10
  • The size of tensor a (xx) must match the size of tensor b (yy)

    Hi, I am trying to run your project. I use CUDA 10.1, all requirements are installed (with torch 1.8.1), and all models are preloaded. But I get an error when running: python3 synthesize.py --text "Hello world" --restore_step 200000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml --duration_control 0.8 --energy_control 0.8 --ref_audio ref.wav

    Removing weight norm...
    Raw Text Sequence: Hello world
    Phoneme Sequence: {HH AH0 L OW1 W ER1 L D}
    Traceback (most recent call last):
      File "synthesize.py", line 268, in <module>
        synthesize(model, args.restore_step, configs, vocoder, batchs, control_values)
      File "synthesize.py", line 152, in synthesize
        d_control=duration_control
      File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/usr/local/work/model/StyleSpeech.py", line 144, in forward
        d_control,
      File "/usr/local/work/model/StyleSpeech.py", line 91, in G
        output, mel_masks = self.mel_decoder(output, style_vector, mel_masks)
      File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/usr/local/work/model/modules.py", line 307, in forward
        enc_seq = self.mel_prenet(enc_seq, mask)
      File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/usr/local/work/model/modules.py", line 259, in forward
        x = x.masked_fill(mask.unsqueeze(-1), 0)
    RuntimeError: The size of tensor a (44) must match the size of tensor b (47) at non-singleton dimension 1
    
    opened by DiDimus 9
  • VCTK datasets

    Hi, I note that your paper evaluates the models' performance on the VCTK dataset, but I do not see the preprocessing files for VCTK. Could you share them? Thank you very much.

    opened by XXXHUA 7
  • training error

    Thanks for your sharing!

    I tried both the naive and main branches using your checkpoints, and the former seems much better. So I trained AISHELL3 models with small changes to your code, and the synthesized waves are good for me.

    However, when I added my own data to AISHELL3, an error occurred:

    Training: 0%| | 3105/900000 [32:05<154:31:49, 1.61it/s]
    Epoch 2: 69%|██████████████████████▏ | 318/459 [05:02<02:14, 1.05it/s]
      File "train.py", line 211, in <module>
        main(args, configs)
      File "train.py", line 87, in main
        output = model(*(batch[2:]))
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 165, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/workspace/StyleSpeech-naive/model/StyleSpeech.py", line 83, in forward
        ) = self.variance_adaptor(
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/workspace/StyleSpeech-naive/model/modules.py", line 404, in forward
        x = x + pitch_embedding
    RuntimeError: The size of tensor a (52) must match the size of tensor b (53) at non-singleton dimension 1

    I only replaced two speakers and preprocessed the data the same way as in the README.

    Do you have any advice for this error? Any suggestion is appreciated.

    opened by MingZJU 6
  • the synthesis result is bad when using pretrain model

    Hello, thanks for sharing.

    I ran into a problem when using the pretrained model to synthesize a demo file: the quality of the synthesized wav is very bad.

    Do you know what might have gone wrong?

    pretrain_model: output/ckpt/LibriTTS_meta_learner/200000.pth.tar
    ref_audio: ref_audio.zip
    demo_txt: {Promises are often like the butterfly, which disappear after beautiful hover. No matter the ending is perfect or not, you cannot disappear from my world.}
    demo_wav: demo.zip

    opened by mnfutao 4
  • Maybe style_prototype can be used instead of ref_mel?

    Hello @keonlee9420, thanks for your contribution on StyleSpeech. After reading your paper and source code, I think style_prototype (which is an embedding matrix) could perhaps be used instead of ref_mel, because there is a CE loss between style_prototype and style_vector that pushes this embedding matrix close to the style. In short, we could give a speaker ID to synthesize that speaker's voice. Is that right?

    opened by forwiat 3
  • architecture shows bad results

    Hi, I have followed your training steps exactly. During training, the StyleSpeech loss decreased, but after meta-learning began, the Meta-StyleSpeech loss started to increase. Can you help with training the model? I can describe my steps in more detail.

    opened by e0xextazy 2
  • UnboundLocalError: local variable 'pitch' referenced before assignment

    Hi, when I run preprocessor.py, I have this problem:

    /preprocessor.py", line 92, in build_from_path
        if len(pitch) > 0:
    UnboundLocalError: local variable 'pitch' referenced before assignment

    When I try to add a global declaration to the function, it shows:

    NameError: name 'pitch' is not defined

    How should this be resolved? I would be grateful if I could get your guidance soon.

    opened by Summerxu86 0
  • How can I improve the synthesized results?

    I have trained the model for 200k steps, and the synthesized results are still extremely bad. (image: loss_curve) This is what my loss curve looks like. What can I do now to improve my synthesized audio results?

    opened by sanjeevani279 1
  • RuntimeError: Error(s) in loading state_dict for Stylespeech

    Hi @keonlee9420, I am getting the following error, while running the naive branch :

    Traceback (most recent call last):
      File "synthesize.py", line 242, in <module>
        model = get_model(args, configs, device, train=False)
      File "/home/azureuser/aditya_workspace/stylespeech_keonlee_naive/utils/model.py", line 21, in get_model
        model.load_state_dict(ckpt["model"], strict=True)
      File "/home/azureuser/aditya_workspace/keonlee/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for StyleSpeech:
    	Missing key(s) in state_dict: "D_t.mel_linear.0.fc_layer.fc_layer.linear.weight_orig", "D_t.mel_linear.0.fc_layer.fc_layer.linear.weight", "D_t.mel_linear.0.fc_layer.fc_layer.linear.weight_u", "D_t.mel_linear.0.fc_layer.fc_layer.linear.weight_orig", "D_t.mel_linear.0.fc_layer.fc_layer.linear.weight_u", "D_t.mel_linear.0.fc_layer.fc_layer.linear.weight_v", "D_t.mel_linear.1.fc_layer.fc_layer.linear.weight_orig", "D_t.mel_linear.1.fc_layer.fc_layer.linear.weight", "D_t.mel_linear.1.fc_layer.fc_layer.linear.weight_u", "D_t.mel_linear.1.fc_layer.fc_layer.linear.weight_orig", "D_t.mel_linear.1.fc_layer.fc_layer.linear.weight_u", "D_t.mel_linear.1.fc_layer.fc_layer.linear.weight_v", "D_t.discriminator_stack.0.fc_layer.fc_layer.linear.weight_orig", "D_t.discriminator_stack.0.fc_layer.fc_layer.linear.weight", "D_t.discriminator_stack.0.fc_layer.fc_layer.linear.weight_u", "D_t.discriminator_stack.0.fc_layer.fc_layer.linear.weight_orig", "D_t.discriminator_stack.0.fc_layer.fc_layer.linear.weight_u", "D_t.discriminator_stack.0.fc_layer.fc_layer.linear.weight_v", "D_t.discriminator_stack.1.fc_layer.fc_layer.linear.weight_orig", "D_t.discriminator_stack.1.fc_layer.fc_layer.linear.weight", "D_t.discriminator_stack.1.fc_layer.fc_layer.linear.weight_u", "D_t.discriminator_stack.1.fc_layer.fc_layer.linear.weight_orig", "D_t.discriminator_stack.1.fc_layer.fc_layer.linear.weight_u", "D_t.discriminator_stack.1.fc_layer.fc_layer.linear.weight_v", "D_t.discriminator_stack.2.fc_layer.fc_layer.linear.weight_orig", "D_t.discriminator_stack.2.fc_layer.fc_layer.linear.weight", "D_t.discriminator_stack.2.fc_layer.fc_layer.linear.weight_u", "D_t.discriminator_stack.2.fc_layer.fc_layer.linear.weight_orig", "D_t.discriminator_stack.2.fc_layer.fc_layer.linear.weight_u", "D_t.discriminator_stack.2.fc_layer.fc_layer.linear.weight_v", "D_t.final_linear.fc_layer.fc_layer.linear.weight_orig", "D_t.final_linear.fc_layer.fc_layer.linear.weight", "D_t.final_linear.fc_layer.fc_layer.linear.weight_u", 
"D_t.final_linear.fc_layer.fc_layer.linear.weight_orig", "D_t.final_linear.fc_layer.fc_layer.linear.weight_u", "D_t.final_linear.fc_layer.fc_layer.linear.weight_v", "D_s.fc_1.fc_layer.fc_layer.linear.weight_orig", "D_s.fc_1.fc_layer.fc_layer.linear.weight", "D_s.fc_1.fc_layer.fc_layer.linear.weight_u", "D_s.fc_1.fc_layer.fc_layer.linear.weight_orig", "D_s.fc_1.fc_layer.fc_layer.linear.weight_u", "D_s.fc_1.fc_layer.fc_layer.linear.weight_v", "D_s.spectral_stack.0.fc_layer.fc_layer.linear.weight_orig", "D_s.spectral_stack.0.fc_layer.fc_layer.linear.weight", "D_s.spectral_stack.0.fc_layer.fc_layer.linear.weight_u", "D_s.spectral_stack.0.fc_layer.fc_layer.linear.weight_orig", "D_s.spectral_stack.0.fc_layer.fc_layer.linear.weight_u", "D_s.spectral_stack.0.fc_layer.fc_layer.linear.weight_v", "D_s.spectral_stack.1.fc_layer.fc_layer.linear.weight_orig", "D_s.spectral_stack.1.fc_layer.fc_layer.linear.weight", "D_s.spectral_stack.1.fc_layer.fc_layer.linear.weight_u", "D_s.spectral_stack.1.fc_layer.fc_layer.linear.weight_orig", "D_s.spectral_stack.1.fc_layer.fc_layer.linear.weight_u", "D_s.spectral_stack.1.fc_layer.fc_layer.linear.weight_v", "D_s.temporal_stack.0.conv_layer.conv_layer.conv.weight_orig", "D_s.temporal_stack.0.conv_layer.conv_layer.conv.weight", "D_s.temporal_stack.0.conv_layer.conv_layer.conv.weight_u", "D_s.temporal_stack.0.conv_layer.conv_layer.conv.bias", "D_s.temporal_stack.0.conv_layer.conv_layer.conv.weight_orig", "D_s.temporal_stack.0.conv_layer.conv_layer.conv.weight_u", "D_s.temporal_stack.0.conv_layer.conv_layer.conv.weight_v", "D_s.temporal_stack.1.conv_layer.conv_layer.conv.weight_orig", "D_s.temporal_stack.1.conv_layer.conv_layer.conv.weight", "D_s.temporal_stack.1.conv_layer.conv_layer.conv.weight_u", "D_s.temporal_stack.1.conv_layer.conv_layer.conv.bias", "D_s.temporal_stack.1.conv_layer.conv_layer.conv.weight_orig", "D_s.temporal_stack.1.conv_layer.conv_layer.conv.weight_u", "D_s.temporal_stack.1.conv_layer.conv_layer.conv.weight_v", 
"D_s.slf_attn_stack.0.w_qs.linear.weight_orig", "D_s.slf_attn_stack.0.w_qs.linear.weight", "D_s.slf_attn_stack.0.w_qs.linear.weight_u", "D_s.slf_attn_stack.0.w_qs.linear.weight_orig", "D_s.slf_attn_stack.0.w_qs.linear.weight_u", "D_s.slf_attn_stack.0.w_qs.linear.weight_v", "D_s.slf_attn_stack.0.w_ks.linear.weight_orig", "D_s.slf_attn_stack.0.w_ks.linear.weight", "D_s.slf_attn_stack.0.w_ks.linear.weight_u", "D_s.slf_attn_stack.0.w_ks.linear.weight_orig", "D_s.slf_attn_stack.0.w_ks.linear.weight_u", "D_s.slf_attn_stack.0.w_ks.linear.weight_v", "D_s.slf_attn_stack.0.w_vs.linear.weight_orig", "D_s.slf_attn_stack.0.w_vs.linear.weight", "D_s.slf_attn_stack.0.w_vs.linear.weight_u", "D_s.slf_attn_stack.0.w_vs.linear.weight_orig", "D_s.slf_attn_stack.0.w_vs.linear.weight_u", "D_s.slf_attn_stack.0.w_vs.linear.weight_v", "D_s.slf_attn_stack.0.layer_norm.weight", "D_s.slf_attn_stack.0.layer_norm.bias", "D_s.slf_attn_stack.0.fc.linear.weight_orig", "D_s.slf_attn_stack.0.fc.linear.weight", "D_s.slf_attn_stack.0.fc.linear.weight_u", "D_s.slf_attn_stack.0.fc.linear.weight_orig", "D_s.slf_attn_stack.0.fc.linear.weight_u", "D_s.slf_attn_stack.0.fc.linear.weight_v", "D_s.fc_2.fc_layer.fc_layer.linear.weight_orig", "D_s.fc_2.fc_layer.fc_layer.linear.weight", "D_s.fc_2.fc_layer.fc_layer.linear.weight_u", "D_s.fc_2.fc_layer.fc_layer.linear.weight_orig", "D_s.fc_2.fc_layer.fc_layer.linear.weight_u", "D_s.fc_2.fc_layer.fc_layer.linear.weight_v", "D_s.V.fc_layer.fc_layer.linear.weight", "D_s.w_b_0.fc_layer.fc_layer.linear.weight", "D_s.w_b_0.fc_layer.fc_layer.linear.bias", "style_prototype.weight".
    	Unexpected key(s) in state_dict: "speaker_emb.weight".
    

    Can you help with this? It seems the pre-trained weights are old and do not conform to the current architecture.

    opened by sirius0503 1
  • time dimension doesn't match

    Training: 0%| | 0/200000 [00:00<?, ?it/s]
    Epoch 1: 0%| | 0/454 [00:00<?, ?it/s]
    Prepare training ...
    Number of StyleSpeech Parameters: 28197333
    Removing weight norm...
    Traceback (most recent call last):
      File "train.py", line 224, in <module>
        main(args, configs)
      File "train.py", line 98, in main
        output = (None, None, model((batch[2:-5])))
      File "/share/mini1/sw/std/python/anaconda3-2019.07/v3.7/envs/StyleSpeech/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/share/mini1/sw/std/python/anaconda3-2019.07/v3.7/envs/StyleSpeech/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 165, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/share/mini1/sw/std/python/anaconda3-2019.07/v3.7/envs/StyleSpeech/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/share/mini1/res/t/vc/studio/timap-en/libritts/StyleSpeech/model/StyleSpeech.py", line 144, in forward
        d_control,
      File "/share/mini1/res/t/vc/studio/timap-en/libritts/StyleSpeech/model/StyleSpeech.py", line 88, in G
        d_control,
      File "/share/mini1/sw/std/python/anaconda3-2019.07/v3.7/envs/StyleSpeech/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/share/mini1/res/t/vc/studio/timap-en/libritts/StyleSpeech/model/modules.py", line 417, in forward
        x = x + pitch_embedding
    RuntimeError: The size of tensor a (132) must match the size of tensor b (130) at non-singleton dimension 1
    Training: 0%| | 1/200000 [00:02<166:02:12, 2.99s/it]

    I think it might be because of the MFA version I used. As mentioned in https://montreal-forced-aligner.readthedocs.io/en/latest/getting_started.html, I installed MFA through conda.

    Then I used mfa align raw_data/LibriTTS lexicon/librispeech-lexicon.txt english preprocessed_data/LibriTTS instead of the command you showed. I can't find a way to run it the way you showed, because I installed MFA through conda.

    opened by MingjieChen 24
Releases(v1.0.2)
Owner
Keon Lee
Expressive Speech Synthesis | Conversational AI | Open-domain Dialog | NLP | Generative Models | Empathic Computing | HCI