HyperSeg: Patch-wise Hypernetwork for Real-time Semantic Segmentation Official PyTorch Implementation

Overview

HyperSeg - Official PyTorch Implementation

Teaser Example segmentations on the PASCAL VOC dataset.

This repository contains the source code for the real-time semantic segmentation method described in the paper:

HyperSeg: Patch-wise Hypernetwork for Real-time Semantic Segmentation
Conference on Computer Vision and Pattern Recognition (CVPR), 2021
Yuval Nirkin, Lior Wolf, Tal Hassner
Paper

Abstract: We present a novel, real-time, semantic segmentation network in which the encoder both encodes and generates the parameters (weights) of the decoder. Furthermore, to allow maximal adaptivity, the weights at each decoder block vary spatially. For this purpose, we design a new type of hypernetwork, composed of a nested U-Net for drawing higher level context features, a multi-headed weight generating module which generates the weights of each block in the decoder immediately before they are consumed, for efficient memory utilization, and a primary network that is composed of novel dynamic patch-wise convolutions. Despite the usage of less-conventional blocks, our architecture obtains real-time performance. In terms of the runtime vs. accuracy trade-off, we surpass state of the art (SotA) results on popular semantic segmentation benchmarks: PASCAL VOC 2012 (val. set) and real-time semantic segmentation on Cityscapes, and CamVid.

Installation

Install the following packages:

conda install pytorch torchvision cudatoolkit=11.1 -c pytorch -c conda-forge
pip install opencv-python ffmpeg-python

Add the parent directory of the repository to PYTHONPATH.

Models

Template Dataset Resolution mIoU (%) FPS Link
HyperSeg-L PASCAL VOC 512x512 80.6 (val) - download
HyperSeg-M CityScapes 1024x512 76.2 (val) 36.9 download
HyperSeg-S CityScapes 1536x768 78.2 (val) 16.1 download
HyperSeg-S CamVid 768x576 78.4 (test) 38.0 download
HyperSeg-L CamVid 1024x768 79.1 (test) 16.6 -

The models FPS was measured on an NVIDIA GeForce GTX 1080TI GPU.

Either download the models under /weights or adjust the model variable in the test configuration files.

Datasets

Dataset # Images Classes Resolution Link
PASCAL VOC 10,582 21 up to 500x500 auto downloaded
CityScapes 5,000 19 960x720 download
CamVid 701 12 2048x1024 download

Either download the datasets under /data or adjust the data_dir variable in the configuration files.

Training

To train the HyperSeg-M model on Cityscapes, set the exp_dir and data_dir paths in cityscapes_efficientnet_b1_hyperseg-m.py and run:

python configs/train/cityscapes_efficientnet_b1_hyperseg-m.py

Testing

Testing a model after training

For example testing the HyperSeg-M model on Cityscapes validation set:

python test.py 'checkpoints/cityscapes/cityscapes_efficientnet_b1_hyperseg-m' \
-td "hyper_seg.datasets.cityscapes.CityscapesDataset('data/cityscapes',split='val',mode='fine')" \
-it "seg_transforms.LargerEdgeResize([512,1024])"

Testing a pretrained model

For example testing the PASCAL VOC HyperSeg-L model using the available test configuration:

python configs/test/vocsbd_efficientnet_b3_hyperseg-l.py
Comments
  • TypeError: an integer is required (got type tuple)

    TypeError: an integer is required (got type tuple)

    Hi, Thanks for your sharing. There is a error when I run the training code which is the sample.

    TRAINING: Epoch: 1 / 360; LR: 1.0e-03; losses: [total: 3.1942 (3.1942); ] bench: [iou: 0.0106 (0.0106); ] : 0%| | 1/1000 [00:07<1:59:20, 7.17s/batches]

    Traceback (most recent call last): File "/home/tt/zyj_ws/hyperseg/configs/train/cityscapes_efficientnet_b1_hyperseg-m.py", line 43, in main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, train_img_transforms=train_img_transforms, File "/home/tt/zyj_ws/hyperseg/train.py", line 248, in main epoch_loss, epoch_iou = proces_epoch(train_loader, train=True) File "/home/tt/zyj_ws/hyperseg/train.py", line 104, in proces_epoch for i, (input, target) in enumerate(pbar): File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/tqdm/std.py", line 1185, in iter for obj in iterable: File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 517, in next data = self._next_data() File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1179, in _next_data return self._process_data(data) File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1225, in _process_data data.reraise() File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/_utils.py", line 429, in reraise raise self.exc_type(msg) TypeError: Caught TypeError in DataLoader worker process 1. Original Traceback (most recent call last): File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop data = fetcher.fetch(index) File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch data = [self.dataset[idx] for idx in possibly_batched_index] File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in data = [self.dataset[idx] for idx in possibly_batched_index] File "/home/tt/zyj_ws/hyperseg/datasets/cityscapes.py", line 220, in getitem image, target = self.transforms(image, target) File "/home/tt/zyj_ws/hyperseg/datasets/seg_transforms.py", line 78, in call input = list(t(*input)) File "/home/tt/zyj_ws/hyperseg/datasets/seg_transforms.py", line 334, in call lbl = F.pad(lbl, (int(self.size[1] - lbl.size[0]), 0), self.lbl_fill, self.padding_mode).copy() File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 426, in pad return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py", line 153, in pad image = ImageOps.expand(img, border=padding, **opts) File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/PIL/ImageOps.py", line 403, in expand draw.rectangle((0, 0, width - 1, height - 1), outline=color, width=border) File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/PIL/ImageDraw.py", line 259, in rectangle self.draw.draw_rectangle(xy, ink, 0, width) TypeError: an integer is required (got type tuple)

    According to the information from the Internet, this may be a problem in the transformation. Can you give me some suggestion on how to do it?

    opened by consideragain 12
  • ValueError: Unknown resampling filter (InterpolationMode.BICUBIC).

    ValueError: Unknown resampling filter (InterpolationMode.BICUBIC).

    I have some problems when I train the model on my own data and I don't know if it is because that I make some mistakes when I load my data. Here are the details: Traceback (most recent call last): File "E:\zjy\hyperseg-main\hyperseg\train.py", line 254, in main epoch_loss, epoch_iou = proces_epoch(val_loader, train=False) File "E:\zjy\hyperseg-main\hyperseg\train.py", line 104, in proces_epoch for i, (input, target) in enumerate(pbar): File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\tqdm\std.py", line 1130, in iter for obj in iterable: File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\dataloader.py", line 521, in next data = self._next_data() File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\dataloader.py", line 1203, in _next_data return self._process_data(data) File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\dataloader.py", line 1229, in _process_data data.reraise() File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch_utils.py", line 438, in reraise raise exception ValueError: Caught ValueError in DataLoader worker process 0. Original Traceback (most recent call last): File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data_utils\worker.py", line 287, in _worker_loop data = fetcher.fetch(index) File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data_utils\fetch.py", line 49, in fetch data = [self.dataset[idx] for idx in possibly_batched_index] File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data_utils\fetch.py", line 49, in data = [self.dataset[idx] for idx in possibly_batched_index] File "E:\zjy\hyperseg-main\hyperseg\datasets\Massachusetts.py", line 104, in getitem img, target = self.transforms(img, target) File "E:\zjy\hyperseg-main\hyperseg\datasets\seg_transforms.py", line 78, in call input = list(t(*input)) File "E:\zjy\hyperseg-main\hyperseg\datasets\seg_transforms.py", line 199, in call img = larger_edge_resize(img, self.size, self.interpolation) File "E:\zjy\hyperseg-main\hyperseg\datasets\seg_transforms.py", line 174, in larger_edge_resize return img.resize(size[::-1], interpolation) File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\PIL\Image.py", line 1861, in resize raise ValueError( ValueError: Unknown resampling filter (InterpolationMode.BICUBIC). Use Image.NEAREST (0), Image.LANCZOS (1), Image.BILINEAR (2), Image.BICUBIC (3), Image.BOX (4) or Image.HAMMING (5)

    opened by JoyeeZhang 9
  • ModuleNotFoundError:No module named ‘hyperseg’

    ModuleNotFoundError:No module named ‘hyperseg’

    I encountered this error when I was training the model, and when I deleted all the "hyperseg" in the code, this error was not reported. Is this the right thing to do?

    opened by Rebufleming 5
  • No module named 'hyper_seg'

    No module named 'hyper_seg'

    Because the training failed, I tried to download cityscapes_efficientnet_b1_hyperseg-m.pth to hyperseg/weights, and renamed it to model_best.pth, I try python test.py 'weights'
    -td "hyper_seg.datasets.cityscapes.CityscapesDataset('data/cityscapes',split='val',mode='fine')"\ -it "seg_transforms.LargerEdgeResize([512,1024])"
    --gpus 0 image Looking forward to your reply!

    opened by leo-hao 5
  • Dataset not found or incomplete.

    Dataset not found or incomplete.

    (liuhaomag) [email protected]:~/data/LH/hyperseg$ python configs/train/cityscapes_efficientnet_b1_hyperseg-m.py
    Traceback (most recent call last):
      File "configs/train/cityscapes_efficientnet_b1_hyperseg-m.py", line 46, in <module>
        scheduler=scheduler, pretrained=pretrained, model=model, criterion=criterion, batch_scheduler=batch_scheduler)
      File "/home/cv428/data/LH/hyperseg/train.py", line 187, in main
        train_dataset = obj_factory(train_dataset, transforms=train_transforms)
      File "/home/cv428/data/LH/hyperseg/utils/obj_factory.py", line 57, in obj_factory
        return obj_exp(*args, **kwargs)
      File "/home/cv428/data/LH/hyperseg/datasets/cityscapes.py", line 156, in __init__
        raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
    RuntimeError: Dataset not found or incomplete. Please make sure all required folders for the specified "split" and "mode" are inside the "root" directory
     
    
    opened by leo-hao 5
  • Saving for C++ inference

    Saving for C++ inference

    First, great work !

    I was trying to train in Python and save it for C++ inference. The classic approach doesn't work:

    annotation_script_module = torch.jit.script(model)
    annotation_script_module.save("my_path")
    

    Do you have any suggestion on how to do it?

    opened by Roios 4
  • MIoU on Pascal VOC is about 0.7 with 50 epoches

    MIoU on Pascal VOC is about 0.7 with 50 epoches

    Thank you for your sharing. I have tried to retrain the model on Pascal VOC dataset, but the miou is only about 0.7 after 50 epoches and it hardly increased, may I know is it correct? or should I continue the training for 160 epoches?

    opened by dansonc 3
  • About Test bug

    About Test bug

    While I use the test script of vocdataset, it will occur a bug----[Errno 2] No such file or directory: 'data/vocsbd/Vocdevkit/V0C2012/3PEGImages/2007_000033.jp9. How to deal with it?

    opened by 870572761 2
  • Question about freezing updates for model

    Question about freezing updates for model

    Hello, thank you so much for releasing your code! I am currently training the model on a custom dataset, and I was wondering if it is possible to freeze the gradient updates for a select number of classes? I am currently zeroing out the gradients from the decoder level 4 BatchNorm3 layer because that is the only layer that has the size of the number of classes within my custom dataset, but when I attempt this, it appears that the weights of the classes that I try to freeze are still changing.

    opened by ma53ma 2
  • CVE-2007-4559 Patch

    CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have began a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15 year old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsantized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks to see if all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions you may contact us through this projects lead researcher Kasimir Schulz.

    opened by TrellixVulnTeam 1
  • The difference between the Resize(512, 1024) and LargerEdgeResize(512, 1024)

    The difference between the Resize(512, 1024) and LargerEdgeResize(512, 1024)

    Hello! @YuvalNirkin,

    I would like to know what is the difference between the torchvision.transforms.Resize function and your designed hyperseg.datasets.seg_transforms.LargerEdgeResize function?

    The results of global accuray and mIoU are slightly difference when I use these two transfromation separably in test phrase.

    Wish your reply.

    opened by J-JunChen 1
  • Another bug about the script of train camvid

    Another bug about the script of train camvid

    While I solve the bug of camvid dataset, I find another bug of the script of train camvid. When running the script of train camvid, it will occur "AttributeError: 'RandomRotation' object has no attribute 'resample' ". Tips: the script of test camvid can run. I don't know how to deal with it.

    opened by 870572761 1
  • Assertion `t >= 0 && t < n_classes` failed.

    Assertion `t >= 0 && t < n_classes` failed.

    Sorry to bother you, I have encontered a probelm: Assertion t >= 0 && t < n_classes failed when I try to use your model to my datasets, I have changed the num_class in my datasets, but I still face the problem, can u help me? here is my change: self.classes = list(range(2)) self.image_classes = calc_classes_per_image(self.masks, 2, cache_file)

    opened by xings-sdnu 1
  • signal_index is not getting updated

    signal_index is not getting updated

    The variable signal_index is not getting updated inside the init_signal2weights function in hyperseg_v1_0.py. Hence, it is assigning singal_index = 0 for every meta block in the decoder.

    Because of this inside the apply_signal2weights function in each meta block, we will not be able to use all the channels of the input signal channels. Essentially we are using [0:signal_channels] of the whole signal where signal_channels of each block is always less than input signal channels.

            # Inside apply_signal2weights function
            w = self.signal2weights(s[:, self.signal_index:self.signal_index + self.signal_channels])[:, :self.hyper_params]
    

    So, is the signal_index supposed to be zero for every meta block?

    opened by Dhamodhar-DDR 1
Owner
Yuval Nirkin
Deep Learning Researcher
Yuval Nirkin
ICRA 2021 - Robust Place Recognition using an Imaging Lidar

Robust Place Recognition using an Imaging Lidar A place recognition package using high-resolution imaging lidar. For best performance, a lidar equippe

Tixiao Shan 293 Dec 27, 2022
N-gram models- Unsmoothed, Laplace, Deleted Interpolation

N-gram models- Unsmoothed, Laplace, Deleted Interpolation

Ravika Nagpal 1 Jan 04, 2022
An integration of several popular automatic augmentation methods, including OHL (Online Hyper-Parameter Learning for Auto-Augmentation Strategy) and AWS (Improving Auto Augment via Augmentation Wise Weight Sharing) by Sensetime Research.

An integration of several popular automatic augmentation methods, including OHL (Online Hyper-Parameter Learning for Auto-Augmentation Strategy) and AWS (Improving Auto Augment via Augmentation Wise

45 Dec 08, 2022
Torch implementation of SegNet and deconvolutional network

Torch implementation of SegNet and deconvolutional network

Fedor Chervinskii 5 Jul 17, 2020
wmctrl ported to Python Ctypes

work in progress wmctrl is a command that can be used to interact with an X Window manager that is compatible with the EWMH/NetWM specification. wmctr

Iyad Ahmed 22 Dec 31, 2022
Implementation of Change-Based Exploration Transfer (C-BET)

Implementation of Change-Based Exploration Transfer (C-BET), as presented in Interesting Object, Curious Agent: Learning Task-Agnostic Exploration.

Simone Parisi 29 Dec 04, 2022
A hybrid SOTA solution of LiDAR panoptic segmentation with C++ implementations of point cloud clustering algorithms. ICCV21, Workshop on Traditional Computer Vision in the Age of Deep Learning

ICCVW21-TradiCV-Survey-of-LiDAR-Cluster Motivation In contrast to popular end-to-end deep learning LiDAR panoptic segmentation solutions, we propose a

YimingZhao 103 Nov 22, 2022
abess: Fast Best-Subset Selection in Python and R

abess: Fast Best-Subset Selection in Python and R Overview abess (Adaptive BEst Subset Selection) library aims to solve general best subset selection,

297 Dec 21, 2022
PixelPyramids: Exact Inference Models from Lossless Image Pyramids (ICCV 2021)

PixelPyramids: Exact Inference Models from Lossless Image Pyramids This repository contains the PyTorch implementation of the paper PixelPyramids: Exa

Visual Inference Lab @TU Darmstadt 8 Dec 11, 2022
Visual Tracking by TridenAlign and Context Embedding

Visual Tracking by TridentAlign and Context Embedding (TACT) Test code for "Visual Tracking by TridentAlign and Context Embedding" Janghoon Choi, Juns

Janghoon Choi 32 Aug 25, 2021
Python codes for Lite Audio-Visual Speech Enhancement.

Lite Audio-Visual Speech Enhancement (Interspeech 2020) Introduction This is the PyTorch implementation of Lite Audio-Visual Speech Enhancement (LAVSE

Shang-Yi Chuang 85 Dec 01, 2022
A PoC Corporation Relationship Knowledge Graph System on top of Nebula Graph.

Corp-Rel is a PoC of Corpartion Relationship Knowledge Graph System. It's built on top of the Open Source Graph Database: Nebula Graph with a dataset

Wey Gu 20 Dec 11, 2022
Combining Latent Space and Structured Kernels for Bayesian Optimization over Combinatorial Spaces

This repository contains source code for the paper Combining Latent Space and Structured Kernels for Bayesian Optimization over Combinatorial Spaces a

9 Nov 21, 2022
Self-Correcting Quantum Many-Body Control using Reinforcement Learning with Tensor Networks

Self-Correcting Quantum Many-Body Control using Reinforcement Learning with Tensor Networks This repository contains the code and data for the corresp

Friederike Metz 7 Apr 23, 2022
This app is a simple example of using Strealit to create a financial data web app.

Streamlit Demo: Finance Chart This app is a simple example of using Streamlit to create a financial data web app. This demo use streamlit, pandas and

91 Jan 02, 2023
Code for the paper: Hierarchical Reinforcement Learning With Timed Subgoals, published at NeurIPS 2021

Hierarchical reinforcement learning with Timed Subgoals (HiTS) This repository contains code for reproducing experiments from our paper "Hierarchical

Autonomous Learning Group 21 Dec 03, 2022
PyDeepFakeDet is an integrated and scalable tool for Deepfake detection.

PyDeepFakeDet An integrated and scalable library for Deepfake detection research. Introduction PyDeepFakeDet is an integrated and scalable Deepfake de

Junke, Wang 49 Dec 11, 2022
Deep learning toolbox based on PyTorch for hyperspectral data classification.

Deep learning toolbox based on PyTorch for hyperspectral data classification.

Nicolas 304 Dec 28, 2022
Pytorch implementation of BRECQ, ICLR 2021

BRECQ Pytorch implementation of BRECQ, ICLR 2021 @inproceedings{ li&gong2021brecq, title={BRECQ: Pushing the Limit of Post-Training Quantization by Bl

Yuhang Li 148 Dec 28, 2022
Repository for the semantic WMI loss

Installation: pip install -e . Installing DL2: First clone DL2 in a separate directory and install it using the following commands: git clone https:/

Nick Hoernle 4 Sep 15, 2022