Generative Query Network (GQN) in PyTorch as described in "Neural Scene Representation and Rendering"

Overview

Update 2019/06/24: A model trained on 10% of the Shepard-Metzler dataset has been added. The following notebook explains the main features of this model: nbviewer

Generative Query Network

This is a PyTorch implementation of the Generative Query Network (GQN) described in the DeepMind paper "Neural scene representation and rendering" by Eslami et al. For an introduction to the model and the problem it addresses, see the article by DeepMind.

The current implementation generalises to any of the datasets described in the paper. However, only the Shepard-Metzler dataset has been implemented so far. To use this dataset, run the provided script:

sh scripts/data.sh data-dir batch-size

The model can be trained in full accordance with the paper by running the file run-gqn.py or by using the provided training script:

sh scripts/gpu.sh data-dir

Implementation

The implementation in this repository includes all of the representation architectures described in the paper, along with a generative model similar to the one described in "Towards conceptual compression" by Gregor et al.

Additionally, this repository contains implementations of the DRAW model and the ConvolutionalDRAW model, both described by Gregor et al.

Comments
  • Training time and testing demo

    Hi Jesper,

    Thank you for your great GQN code for real images. I am a little curious about the following: How many epochs does it take to train a model on real images? How much training data do you use (as a percentage of the full training dataset)? Can you show a testing demo?

    Thank you very much!

    Best wishes, Mingjia Chen

    opened by mjchen611 22
  • ConvLSTM did not concat hidden from last round

    In the structure presented in the paper, the hidden state from the previous round is concatenated with the input before the other operations are applied. But it seems your LSTM does not use the hidden state from the previous round.

    opened by Tom-the-Cat 7
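
For reference, here is a minimal sketch of a convolutional LSTM cell in which the previous hidden state is concatenated with the input before the gate convolution, as the issue describes. The class name `ConvLSTMCellSketch`, the kernel size, and the tensor shapes are hypothetical choices for illustration; this is not the repository's cell.

```python
import torch
import torch.nn as nn


class ConvLSTMCellSketch(nn.Module):
    """Hypothetical ConvLSTM cell: gates are computed from cat([x, h_prev])."""

    def __init__(self, in_channels, hidden_channels, kernel_size=5):
        super().__init__()
        padding = kernel_size // 2  # keeps the spatial size unchanged
        # A single convolution produces all four gates at once.
        self.gates = nn.Conv2d(in_channels + hidden_channels,
                               4 * hidden_channels,
                               kernel_size, padding=padding)

    def forward(self, x, state):
        h_prev, c_prev = state
        # Concatenate the input with the hidden state from the previous step.
        stacked = torch.cat([x, h_prev], dim=1)
        f, i, o, g = torch.chunk(self.gates(stacked), 4, dim=1)
        c = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c


cell = ConvLSTMCellSketch(in_channels=3, hidden_channels=8)
x = torch.randn(2, 3, 16, 16)
h = torch.zeros(2, 8, 16, 16)
c = torch.zeros(2, 8, 16, 16)
h, c = cell(x, (h, c))
```

Calling the cell repeatedly with the returned `(h, c)` carries information from one round to the next.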
  • Bad images in training

    While playing around with the sm5 dataset, I noticed that some of the images are badly rendered. I'm not sure whether this will pose a problem for training; I just wanted to point it out.

    opened by versatran01 7
  • Question about generator

    In the top docstring of generator.py, you mentioned that

    The inference-generator architecture is conceptually
    similar to the encoder-decoder pair seen in variational
    autoencoders.
    

    I don't quite understand this part and I would really appreciate it if you could explain a bit or point me to some related articles. For the generator, I can see how it is similar to a decoder: it takes the latent z, the query viewpoint v, and the aggregated representation r, and eventually outputs the image x_mu.

    But I'm a bit confused by the inference network being the counterpart of the encoder.

    opened by versatran01 7
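
The analogy can be illustrated with a small sketch. In a conditional VAE, the encoder role is played by whatever network sees the target image x and produces the posterior over z, while the decoder side only has access to a prior over z. The tensors below are hypothetical stand-ins for the outputs of the inference and generator networks at a single step, not values from generator.py.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical stand-ins for the outputs of the two networks at one step.
# Prior p(z | v, r): the generator never sees the target image x.
prior_mu, prior_std = torch.zeros(4, 3), torch.ones(4, 3)
# Posterior q(z | x, v, r): the inference network also sees x (encoder role).
post_mu, post_std = torch.randn(4, 3), torch.full((4, 3), 0.5)

prior = Normal(prior_mu, prior_std)
posterior = Normal(post_mu, post_std)

# Training: sample z from the posterior (reparameterised) and decode it.
z_train = posterior.rsample()
# Generation at test time: no x is available, so sample from the prior.
z_test = prior.sample()

# The KL term of the ELBO pulls the posterior towards the prior.
kl = kl_divergence(posterior, prior).sum(dim=1)
```

Under this reading, the inference network is used only during training, which is why it corresponds to the encoder of a VAE.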
  • Loss Change

    Dear wohlert,

    May I ask you several questions?

    1. I tried to train this network on the Mazes data from https://github.com/deepmind/gqn-datasets. It seems to contain only 5% of the data, around 110000 examples, instead of the full dataset. Is that right?

    2. I trained for 30000 steps, but the ELBO loss only converged to 6800, which is very different from the value of around 7 reported in the supplementary material. What approximate value did you achieve on the data you used?

    3. In the visualisation from the run in Question 2, the reconstruction seems reasonable, but the sampling results are quite bad. Did you encounter the same problem?

    Many thanks, Bing

    opened by BingCS 5
  • Questions on data preparing

    Hi, Wohlert:

    After converting the data with your scripts, I visualized some of the images in the *.pt files and found pictures like this: Figure_1-1

    What's wrong with that? Also, I'm confused about your batch operation. If you batch the sequences as you convert them, does that mean you won't batch them again when using the dataloader?

    Thanks

    opened by Kyridiculous2 5
  • Training crashes at the same spot for both Shepard Metzler datasets

    Some context:

    • I downloaded and converted the datasets via data.sh and set batch size to 12. Note that I am using TensorFlow 1.14 for reading the tfrecord files and converting them.
    • I use gpu.sh to run the training script. I set the batch size to either of [1,12,36,72] and DataParallel to True to use 4 GPUs

    But after a short time I get the following errors if I use any batch size higher than 1. This happens on iterations 40, 13 and 6 with batch sizes 12, 36 and 72 respectively, and for both Shepard Metzler datasets. Why am I getting these errors? Does batch size 1 in the training code mean reading one of the .pt.gz files? If so, setting batch size to 1 in the training script should actually mean 12. Would that be correct?

    Here's what I get for the data set with 5 parts when I set batch size to 36 for instance:

    Epoch [1/200]: [13/1856]   1%|▊                                                                                                                       , elbo=-2.1e+4, kl=827, mu=5e-6, sigma=2 [00:21<52:34]Current run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Engine run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Traceback (most recent call last):
      File "../run-gqn.py", line 183, in <module>
        trainer.run(train_loader, args.n_epochs)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 850, in run
        return self._internal_run()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 952, in _internal_run
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 937, in _internal_run
        hours, mins, secs = self._run_once_on_dataset()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 705, in _run_once_on_dataset
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 655, in _run_once_on_dataset
        batch = next(self._dataloader_iter)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 801, in __next__
        return self._process_data(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 846, in _process_data
        data.reraise()
      File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 385, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    
    opened by Amir-Arsalan 4
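
One reading of the traceback above is that each dataset item is already a mini-batch whose leading dimension can differ (12 for most items, 8 for a leftover chunk), so the default collate function fails when it tries to stack them. The sketch below, under that assumption, concatenates items along the batch dimension instead. `PreBatchedSketch` and `concat_collate` are hypothetical names, and the tensor shapes are placeholders rather than the repository's actual layout.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PreBatchedSketch(Dataset):
    """Hypothetical dataset where each item is already a mini-batch."""

    def __init__(self, sizes):
        # Items with different leading sizes, e.g. a short final chunk of 8.
        self.items = [(torch.zeros(n, 3, 4, 4), torch.zeros(n, 7))
                      for n in sizes]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def concat_collate(samples):
    """Join pre-batched items along dim 0 instead of stacking them."""
    images, viewpoints = zip(*samples)
    return torch.cat(images, dim=0), torch.cat(viewpoints, dim=0)


loader = DataLoader(PreBatchedSketch([12, 12, 8]), batch_size=3,
                    collate_fn=concat_collate)
images, viewpoints = next(iter(loader))
```

With this collate function, a loader batch size of 3 over items of size 12, 12 and 8 yields a single batch of 32 scenes, so the effective batch size is the loader batch size times the size of each stored chunk.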
  • AttributeError: 'int' object has no attribute 'size'

    In draw.py, I get this error at line 118 (batch_size = z.size(0)). Sorry if this is obvious, and thanks for the help anyway.

    ~ % pip show torch :( Name: torch Version: 1.0.1.post2

    opened by DRM-Free 4
  • Increase dimension of viewpoint and representation

    Thanks for this implementation. One question I have is when increasing the dimension of viewpoint and representation, you use torch.repeat. Is there any reason for this? Can one possibly use interpolate?

    In the original paper it says "when concatenating viewpoint v to an image or feature map, its values are ‘broadcast’ in the spatial dimensions to obtain the correct size. "

    The word 'broadcast' is not precisely defined, hence the question.

    opened by versatran01 4
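
A short sketch of what broadcasting in the spatial dimensions can look like in PyTorch, assuming the viewpoint is a 7-dimensional vector and the representation has a 1x1 spatial extent (both shapes are assumptions for illustration). It compares `repeat`, `expand`, and nearest-neighbour `interpolate`, which all produce the same values when the source is 1x1.

```python
import torch
import torch.nn.functional as F

batch, v_dim, r_dim, size = 2, 7, 256, 16
v = torch.randn(batch, v_dim)           # one viewpoint vector per scene
r = torch.randn(batch, r_dim, 1, 1)     # representation, assumed 1x1 spatial
x = torch.randn(batch, 3, size, size)   # image or feature map

# Tile every value over the spatial grid with Tensor.repeat (copies memory).
v_map = v.view(batch, v_dim, 1, 1).repeat(1, 1, size, size)
r_map = r.repeat(1, 1, size, size)

# expand gives the same values as a view, without copying memory.
v_view = v.view(batch, v_dim, 1, 1).expand(-1, -1, size, size)

# Nearest-neighbour upsampling of a 1x1 map also copies the single value.
v_interp = F.interpolate(v.view(batch, v_dim, 1, 1),
                         size=(size, size), mode="nearest")

# Concatenate along the channel dimension once the sizes agree.
merged = torch.cat([x, v_map, r_map], dim=1)
```

For a 1x1 source the three approaches agree, so the choice is mostly about memory and readability. They would differ only if the source had a larger spatial extent.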
  • Learning rate change

    Regarding line 113 of run-gqn.py: does this change the learning rate of the Adam optimizer? This post suggests something different:

    https://stackoverflow.com/questions/48324152/pytorch-how-to-change-the-learning-rate-of-an-optimizer-at-any-given-moment-no

    opened by david-bernstein 4
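
A common way to change the learning rate of a PyTorch optimizer in place is to write to its `param_groups`. The sketch below assumes that pattern together with a linear annealing schedule. The helper names and the constants are assumptions for illustration and are not taken from run-gqn.py.

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)


def set_learning_rate(optimizer, lr):
    """Write the new rate into every parameter group of the optimizer."""
    for group in optimizer.param_groups:
        group["lr"] = lr


def annealed_lr(step, mu_i=5e-4, mu_f=5e-5, n=1.6e6):
    """Linear annealing from mu_i to mu_f over n steps (values assumed)."""
    return max(mu_f + (mu_i - mu_f) * (1 - step / n), mu_f)


# Halfway through the schedule the rate sits midway between mu_i and mu_f.
set_learning_rate(optimizer, annealed_lr(step=800_000))
```

Assigning to an attribute such as `optimizer.lr` would not have this effect, because the optimizer reads the rate from each entry of `param_groups` on every step.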
  • Using the rooms data?

    I wanted to try your code on the rooms data but during conversion, I get these errors. What could I be doing wrong? Note that for the rooms data with moving camera I set the number of camera parameters to 7:

    Traceback (most recent call last):
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 119, in worker
        result = (True, func(*args, **kwds))
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
        return list(map(*args))
      File "tfrecord-converter.py", line 66, in convert
        for i, batch in enumerate(batch_process(record)):
      File "tfrecord-converter.py", line 29, in chunk
        for first in iterator:
      File "tfrecord-converter.py", line 40, in process
        'cameras': tf.FixedLenFeature(shape=SEQ_DIM * POSE_DIM, dtype=tf.float32)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1019, in parse_single_example
        serialized, features, example_names, name
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1063, in parse_single_example_v2_unoptimized
        return parse_single_example_v2(serialized, features, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2089, in parse_single_example_v2
        dense_defaults, dense_shapes, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2206, in _parse_single_example_v2_raw
        name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1164, in parse_single_example
        ctx=_ctx)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1260, in parse_single_example_eager_fallback
        attrs=_attrs, ctx=_ctx, name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 67, in quick_execute
        six.raise_from(core._status_to_exception(e.code, message), None)
      File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "tfrecord-converter.py", line 98, in <module>
        pool.map(f, records)
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 266, in map
        return self._map_async(func, iterable, mapstar, chunksize).get()
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
        raise self._value
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    
    opened by Amir-Arsalan 3
Releases(0.1)
Owner
Jesper Wohlert