Generative Query Network (GQN) in PyTorch as described in "Neural Scene Representation and Rendering"

Overview

Update 2019/06/24: A model trained on 10% of the Shepard-Metzler dataset has been added. The main features of this model are explained in the following notebook: nbviewer

Generative Query Network

This is a PyTorch implementation of the Generative Query Network (GQN) described in the DeepMind paper "Neural scene representation and rendering" by Eslami et al. For an introduction to the model and the problem described in the paper, see the article by DeepMind.

The implementation generalises to any of the datasets described in the paper; however, only the Shepard-Metzler dataset is currently supported. To download and convert this dataset, use the provided script:

sh scripts/data.sh data-dir batch-size

The model can be trained in full, in accordance with the paper, by running the file run-gqn.py or by using the provided training script:

sh scripts/gpu.sh data-dir

Implementation

This repository implements all of the representation architectures described in the paper, along with a generative model similar to the one described in "Towards conceptual compression" by Gregor et al.
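
As a rough illustration of how the representation is formed: the paper aggregates scene information by summing the representations f(x_k, v_k) computed for each context image x_k and viewpoint v_k. The sketch below uses a deliberately small convolutional encoder, `TinyRepresentation`, which is a hypothetical stand-in and not one of the paper's architectures; it assumes 64×64 RGB images and 7-dimensional viewpoints that are broadcast over the spatial dimensions before concatenation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyRepresentation(nn.Module):
    """Hypothetical, simplified representation network (not the paper's architectures)."""

    def __init__(self, v_dim: int = 7, r_dim: int = 32):
        super().__init__()
        # Downsample a 64x64 image to 16x16 feature maps.
        self.conv1 = nn.Conv2d(3, r_dim, kernel_size=4, stride=4)
        # Mix image features with the spatially broadcast viewpoint.
        self.conv2 = nn.Conv2d(r_dim + v_dim, r_dim, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # x: (N, 3, 64, 64), v: (N, v_dim)
        h = F.relu(self.conv1(x))
        # Copy the viewpoint vector to every spatial location of h.
        v_map = v[:, :, None, None].expand(-1, -1, h.size(2), h.size(3))
        return F.relu(self.conv2(torch.cat([h, v_map], dim=1)))


def aggregate(net: nn.Module, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Sum per-view representations; x is (B, K, 3, H, W) and v is (B, K, v_dim)."""
    b, k = x.shape[:2]
    r = net(x.reshape(b * k, *x.shape[2:]), v.reshape(b * k, -1))
    # Order-invariant aggregation over the K context views.
    return r.reshape(b, k, *r.shape[1:]).sum(dim=1)
```

Because the aggregation is a sum, permuting the K context views leaves the aggregated representation unchanged up to floating-point error.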

This repository also contains implementations of the DRAW model and the ConvolutionalDRAW model, both described by Gregor et al.
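
For readers unfamiliar with DRAW, its core idea is that a recurrent encoder/decoder pair iteratively adds updates to a canvas, sampling a latent variable at every step. The sketch below is a minimal attention-free variant on flattened vectors, intended only as an illustration under stated assumptions; `MiniDRAW` and all of its sizes are hypothetical, and this is not the repository's draw.py.

```python
import torch
import torch.nn as nn


class MiniDRAW(nn.Module):
    """Hypothetical attention-free DRAW sketch on flattened inputs."""

    def __init__(self, x_dim=784, h_dim=64, z_dim=8, steps=4):
        super().__init__()
        self.steps, self.h_dim = steps, h_dim
        # The encoder reads the input, the current error image and the decoder state.
        self.encoder = nn.LSTMCell(2 * x_dim + h_dim, h_dim)
        self.decoder = nn.LSTMCell(z_dim, h_dim)
        self.q = nn.Linear(h_dim, 2 * z_dim)  # posterior mean and log-std
        self.write = nn.Linear(h_dim, x_dim)  # additive canvas update

    def forward(self, x):
        b = x.size(0)
        canvas = x.new_zeros(x.shape)
        h_enc = c_enc = h_dec = c_dec = x.new_zeros(b, self.h_dim)
        kl = x.new_zeros(b)
        for _ in range(self.steps):
            error = x - torch.sigmoid(canvas)
            h_enc, c_enc = self.encoder(torch.cat([x, error, h_dec], dim=1), (h_enc, c_enc))
            mu, log_std = self.q(h_enc).chunk(2, dim=1)
            std = log_std.exp()
            z = mu + std * torch.randn_like(std)  # reparameterised sample
            # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, summed over dimensions.
            kl = kl + 0.5 * (mu.pow(2) + std.pow(2) - 1 - 2 * log_std).sum(dim=1)
            h_dec, c_dec = self.decoder(z, (h_dec, c_dec))
            canvas = canvas + self.write(h_dec)
        return torch.sigmoid(canvas), kl
```

The final image is the sigmoid of the accumulated canvas, and the KL terms from every step are summed into a single regulariser per example.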

Comments
  • Training time and testing demo

    Hi Jesper,

    Thank you for your great GQN code. I am a little curious about the following questions regarding real images: How many epochs does it take to train a model on real images? How much training data do you use (as a percentage of the full training dataset)? Could you show a testing demo?

    Thank you very much!

    Best wishes, Mingjia Chen

    opened by mjchen611 22
  • ConvLSTM did not concat hidden from last round

    In the structure presented in the paper, the hidden state from the previous step is concatenated with the input before the remaining operations. However, it seems your LSTM does not use the hidden state from the previous step.

    opened by Tom-the-Cat 7
  • Bad images in training

    While playing around with the sm5 dataset, I noticed that some of the images are badly rendered. I'm not sure whether this will pose any problem for training; I just wanted to point it out.

    opened by versatran01 7
  • Question about generator

    In the top docstring of generator.py, you mentioned that

    The inference-generator architecture is conceptually
    similar to the encoder-decoder pair seen in variational
    autoencoders.
    

    I don't quite understand this part, and I would really appreciate it if you could explain it a bit or point me to some related articles. For the generator, I can see how it is similar to a decoder: it takes the latent z, the query viewpoint v and the aggregated representation r, and eventually outputs the image x_mu.

    But I'm a bit confused by the inference network being the counterpart of the encoder.

    opened by versatran01 7
  • Loss Change

    Dear wohlert,

    May I ask you a few questions?

    1. I tried to train this network on the Mazes data from https://github.com/deepmind/gqn-datasets. It seems to contain only 5% of the data (around 110000) rather than the full dataset. Is that right?

    2. I trained for 30000 steps, but the ELBO loss only converged to 6800, which is very different from the value of around 7 reported in the supplementary material. What approximate value do you achieve on the data you used?

    3. In the visualisation of the model from Question 2, the reconstructions seem reasonable, but the sampling results are quite bad. Have you encountered the same problem?

    Many thanks, Bing

    opened by BingCS 5
  • Questions on data preparing

    Hi, Wohlert:

    After converting the data with your scripts, I visualized some of the images in the *.pt files and found pictures like this (Figure_1-1).

    What's wrong with them? I'm also confused about your batching: if you batch the sequences as you convert them, does that mean you don't batch them again when using the DataLoader?

    Thanks

    opened by Kyridiculous2 5
  • Training crashes at the same spot for both Shepard Metzler datasets

    Some context:

    • I downloaded and converted the datasets via data.sh and set batch size to 12. Note that I am using TensorFlow 1.14 for reading the tfrecord files and converting them.
    • I use gpu.sh to run the training script. I set the batch size to one of [1, 12, 36, 72] and set DataParallel to True to use 4 GPUs.

    However, after a short time I get the following errors whenever I use a batch size higher than 1. This happens at iterations 40, 13 and 6 with batch sizes 12, 36 and 72, respectively, and for both Shepard-Metzler datasets. Why am I getting these errors? Does a batch size of 1 in the training code mean reading one of the .pt.gz files? If so, a batch size of 1 in the training script would actually mean 12. Would that be correct?

    Here's what I get for the dataset with 5 parts when I set the batch size to 36, for instance:

    Epoch [1/200]: [13/1856]   1%|▊ , elbo=-2.1e+4, kl=827, mu=5e-6, sigma=2 [00:21<52:34]
    Current run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Engine run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Traceback (most recent call last):
      File "../run-gqn.py", line 183, in <module>
        trainer.run(train_loader, args.n_epochs)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 850, in run
        return self._internal_run()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 952, in _internal_run
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 937, in _internal_run
        hours, mins, secs = self._run_once_on_dataset()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 705, in _run_once_on_dataset
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 655, in _run_once_on_dataset
        batch = next(self._dataloader_iter)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 801, in __next__
        return self._process_data(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 846, in _process_data
        data.reraise()
      File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 385, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    
    opened by Amir-Arsalan 4
  • AttributeError: 'int' object has no attribute 'size'

    In draw.py, I get this error at line 118 (batch_size = z.size(0)). Sorry if this is obvious; thanks for any help anyway.

    ~ % pip show torch
    Name: torch
    Version: 1.0.1.post2

    opened by DRM-Free 4
  • Increase dimension of viewpoint and representation

    Thanks for this implementation. One question: when increasing the dimensions of the viewpoint and representation, you use torch.repeat. Is there a reason for this? Could one use interpolate instead?

    In the original paper it says "when concatenating viewpoint v to an image or feature map, its values are ‘broadcast’ in the spatial dimensions to obtain the correct size. "

    The word 'broadcast' is not precisely defined, hence the question.

    opened by versatran01 4
  • Learning rate change

    Regarding line 113 of run-gqn.py: does this change the learning rate of the Adam optimizer? This post shows something different:

    https://stackoverflow.com/questions/48324152/pytorch-how-to-change-the-learning-rate-of-an-optimizer-at-any-given-moment-no

    opened by david-bernstein 4
  • Using the rooms data?

    I wanted to try your code on the rooms data, but during conversion I get the errors below. What could I be doing wrong? Note that for the rooms data with a moving camera, I set the number of camera parameters to 7:

    Traceback (most recent call last):
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 119, in worker
        result = (True, func(*args, **kwds))
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
        return list(map(*args))
      File "tfrecord-converter.py", line 66, in convert
        for i, batch in enumerate(batch_process(record)):
      File "tfrecord-converter.py", line 29, in chunk
        for first in iterator:
      File "tfrecord-converter.py", line 40, in process
        'cameras': tf.FixedLenFeature(shape=SEQ_DIM * POSE_DIM, dtype=tf.float32)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1019, in parse_single_example
        serialized, features, example_names, name
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1063, in parse_single_example_v2_unoptimized
        return parse_single_example_v2(serialized, features, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2089, in parse_single_example_v2
        dense_defaults, dense_shapes, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2206, in _parse_single_example_v2_raw
        name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1164, in parse_single_example
        ctx=_ctx)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1260, in parse_single_example_eager_fallback
        attrs=_attrs, ctx=_ctx, name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 67, in quick_execute
        six.raise_from(core._status_to_exception(e.code, message), None)
      File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "tfrecord-converter.py", line 98, in <module>
        pool.map(f, records)
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 266, in map
        return self._map_async(func, iterable, mapstar, chunksize).get()
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
        raise self._value
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    
    opened by Amir-Arsalan 3
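
On the "ConvLSTM did not concat hidden from last round" thread above: one common way to implement a convolutional LSTM is to concatenate the previous hidden state with the current input along the channel axis and compute all gates with a single convolution. The cell below is a generic sketch of that pattern, assuming "same" padding; `ConvLSTMCellSketch` is a hypothetical name and this is not the repository's class.

```python
import torch
import torch.nn as nn


class ConvLSTMCellSketch(nn.Module):
    """Generic ConvLSTM cell: gates computed from concat(input, previous hidden)."""

    def __init__(self, in_channels: int, hidden_channels: int, kernel_size: int = 5):
        super().__init__()
        # One convolution produces the input, forget, output and candidate gates.
        self.gates = nn.Conv2d(in_channels + hidden_channels, 4 * hidden_channels,
                               kernel_size, padding=kernel_size // 2)

    def forward(self, x, state):
        h_prev, c_prev = state
        # The previous hidden state enters here, concatenated with the input.
        i, f, o, g = self.gates(torch.cat([x, h_prev], dim=1)).chunk(4, dim=1)
        c = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c
```

Convolving the concatenation is mathematically equivalent to convolving `x` and `h_prev` separately and summing the results, so an implementation that adds a separate convolution of the hidden state also uses it; what matters is that `h` is carried from one step to the next.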
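
On the "Question about generator" thread above: in a variational autoencoder, the encoder produces a posterior q(z | x) and the decoder models p(x | z). In GQN, the inference network plays the encoder's role because it also sees the target image (together with the query viewpoint and representation), whereas the generator's prior sees only the viewpoint and representation. Training maximises the ELBO: the reconstruction log-likelihood minus the KL divergence between posterior and prior. The function below sketches these two terms for a single latent step with `torch.distributions`; its argument names are hypothetical and it is not the repository's loss code.

```python
import torch
from torch.distributions import Normal, kl_divergence


def elbo_terms(x, x_mu, sigma, q_mu, q_std, p_mu, p_std):
    """Return (reconstruction log-likelihood, KL) per batch element.

    q_*: posterior from the inference network (sees the target image).
    p_*: prior from the generator (does not see the target image).
    """
    # Gaussian log-likelihood of the target under the generated mean image.
    log_lik = Normal(x_mu, sigma).log_prob(x).flatten(1).sum(dim=1)
    # KL between posterior and prior, summed over latent dimensions.
    kl = kl_divergence(Normal(q_mu, q_std), Normal(p_mu, p_std)).flatten(1).sum(dim=1)
    return log_lik, kl
```

The ELBO is `log_lik - kl`; a training loss would minimise its negative. When the posterior matches the prior, the KL term is zero and only reconstruction drives learning.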
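
On the "Loss Change" thread above: one thing worth checking when comparing ELBO values is how the likelihood is reduced. Summing a Gaussian negative log-likelihood over every pixel and channel gives a value thousands of times larger than averaging it per element, so numbers reported under the two conventions are not directly comparable. Which convention the supplementary material uses is not established here; `gaussian_nll` is a hypothetical helper that only illustrates the size of the gap.

```python
import torch
from torch.distributions import Normal


def gaussian_nll(x, x_mu, sigma, reduction="sum"):
    """Negative log-likelihood of x under N(x_mu, sigma^2), per batch element."""
    nll = -Normal(x_mu, sigma).log_prob(x).flatten(1)
    if reduction == "sum":
        return nll.sum(dim=1)   # summed over all pixels and channels
    if reduction == "mean":
        return nll.mean(dim=1)  # averaged per pixel and channel
    raise ValueError(f"unknown reduction: {reduction}")
```

For a 3×64×64 image the summed value is exactly 12288 times the per-element mean, so a gap of several orders of magnitude between two reported losses may come from normalisation alone rather than from a training problem.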
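
On the "Questions on data preparing" thread above: if the conversion step stores several scenes per file (as the batch-size argument to data.sh suggests), a loader can either treat each file as one item or index individual scenes so that the DataLoader batch size counts scenes. The sketch below does the latter. It assumes, hypothetically, that each file holds a tuple `(images, viewpoints)` whose first dimension is the number of scenes in that file; the real file layout may differ, and in-memory tensors stand in for loaded files.

```python
import bisect

import torch
from torch.utils.data import Dataset, DataLoader


class SceneIndexDataset(Dataset):
    """Hypothetical: expose individual scenes from pre-batched chunks.

    `chunks` is a list of (images, viewpoints) tuples, each with a leading
    scene dimension, e.g. what one converted file might contain.
    """

    def __init__(self, chunks):
        self.chunks = chunks
        # Cumulative scene counts map a global index to (chunk, offset).
        self.ends = []
        total = 0
        for images, _ in chunks:
            total += images.size(0)
            self.ends.append(total)

    def __len__(self):
        return self.ends[-1] if self.ends else 0

    def __getitem__(self, idx):
        chunk = bisect.bisect_right(self.ends, idx)
        start = self.ends[chunk - 1] if chunk > 0 else 0
        images, viewpoints = self.chunks[chunk]
        return images[idx - start], viewpoints[idx - start]
```

Because every item now has the same shape, the default collate function can stack items regardless of how many scenes each file held.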
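
On the "Training crashes at the same spot for both Shepard Metzler datasets" thread above: `Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1` is the error `torch.stack` raises when tensors of different sizes are stacked. One plausible reading, which is an assumption rather than a confirmed diagnosis, is that each dataset item is a whole converted file of 12 scenes and some files hold only 8. A collate function that concatenates along the scene dimension instead of stacking tolerates such ragged chunks; `concat_collate` is a hypothetical helper, not part of the repository.

```python
import torch


def concat_collate(batch):
    """Concatenate pre-batched chunks along dim 0 instead of stacking them.

    `batch` is a list of tuples of tensors, e.g. [(images, viewpoints), ...],
    where chunks may hold different numbers of scenes.
    """
    return tuple(torch.cat(parts, dim=0) for parts in zip(*batch))
```

It would be passed as `collate_fn=concat_collate` when constructing the DataLoader. With this approach the number of scenes per step varies from batch to batch; if a fixed batch size matters, indexing individual scenes is the cleaner option.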
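
On the "Increase dimension of viewpoint and representation" thread above: for a vector reshaped to 1×1 spatial extent, repeating, expanding and nearest-neighbour interpolation all produce the same values, namely a copy of the vector at every spatial location, which is one natural reading of "broadcast". `expand` returns a view without copying memory, whereas `repeat` allocates a new tensor. The helper below, `broadcast_viewpoint`, is hypothetical and illustrates this for a 7-dimensional viewpoint.

```python
import torch
import torch.nn.functional as F


def broadcast_viewpoint(v: torch.Tensor, height: int, width: int, how: str = "expand"):
    """Tile a (B, D) viewpoint over space to shape (B, D, height, width)."""
    v = v[:, :, None, None]                     # (B, D, 1, 1)
    if how == "expand":
        return v.expand(-1, -1, height, width)  # view, no copy
    if how == "repeat":
        return v.repeat(1, 1, height, width)    # materialised copy
    if how == "interpolate":
        return F.interpolate(v, size=(height, width), mode="nearest")
    raise ValueError(f"unknown method: {how}")
```

When the result is then concatenated with a feature map via `torch.cat`, a new tensor is allocated anyway, so the choice mainly affects intermediate memory rather than the values the network sees.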
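
On the "Learning rate change" thread above: in PyTorch, an optimizer's learning rate is stored in its `param_groups`, so writing a new value to `group["lr"]` for each group changes the step size used by subsequent `step()` calls, while rebinding a separate variable does not. Whether line 113 of run-gqn.py does this cannot be checked from this page. The sketch pairs that update with a hypothetical linear annealing schedule whose constants are illustrative only.

```python
import torch


def set_learning_rate(optimizer: torch.optim.Optimizer, lr: float) -> None:
    """Change the learning rate used by subsequent optimizer.step() calls."""
    for group in optimizer.param_groups:
        group["lr"] = lr


def annealed_lr(step: int, lr_start: float = 5e-4, lr_end: float = 5e-5,
                anneal_steps: int = 1_600_000) -> float:
    """Hypothetical linear annealing from lr_start to lr_end, then constant."""
    frac = min(step / anneal_steps, 1.0)
    return lr_start + frac * (lr_end - lr_start)
```

Calling `set_learning_rate(optimizer, annealed_lr(step))` once per iteration keeps Adam's moment estimates intact while changing only the step size.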
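
On the "Using the rooms data?" thread above: the error is raised for the `frames` key, which suggests that the expected number of frames per scene, rather than the camera dimension, may not match the rooms records; this is an inference from the traceback, not a confirmed fix. Separately, to the best of our knowledge the raw camera records in DeepMind's gqn-datasets store 5 values per frame (x, y, z, yaw, pitch), which are commonly expanded to a 7-dimensional viewpoint (position plus cosine and sine of yaw and pitch) only after parsing. The function below sketches that expansion; its name and the output ordering are assumptions.

```python
import torch


def expand_viewpoint(raw: torch.Tensor) -> torch.Tensor:
    """Map (..., 5) raw cameras [x, y, z, yaw, pitch] to (..., 7) viewpoints.

    Output order (an assumption): [x, y, z, cos(yaw), sin(yaw), cos(pitch), sin(pitch)].
    """
    position, yaw, pitch = torch.split(raw, [3, 1, 1], dim=-1)
    return torch.cat([position, torch.cos(yaw), torch.sin(yaw),
                      torch.cos(pitch), torch.sin(pitch)], dim=-1)
```

Encoding angles as cosine and sine avoids the discontinuity at ±π that a raw angle would present to the network.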
Releases: 0.1
Owner: Jesper Wohlert