Feature extraction made simple with torchextractor

Overview

torchextractor: PyTorch Intermediate Feature Extraction


Introduction

Too many times, model definitions get remorselessly copy-pasted just because the forward function does not return what the person expects. You provide module names, and torchextractor takes care of the extraction for you. It's never been easier to extract features, add an extra loss or plug another head into a network. Let us know what amazing things you build with torchextractor!

Installation

pip install torchextractor  # stable
pip install git+https://github.com/antoinebrl/torchextractor.git  # latest

Requirements:

  • Python >= 3.6
  • torch >= 1.4.0

Usage

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

# {
#   'layer1': torch.Size([7, 64, 56, 56]),
#   'layer2': torch.Size([7, 128, 28, 28]),
#   'layer3': torch.Size([7, 256, 14, 14]),
#   'layer4': torch.Size([7, 512, 7, 7]),
# }
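Under the hood, this kind of extraction relies on PyTorch forward hooks. The sketch below is a minimal, hypothetical re-implementation for illustration only. The class name `MiniExtractor` is made up and is not part of torchextractor's API, and the toy model is an assumption. It registers `register_forward_hook` on each requested submodule and returns `(output, features)`, like the call above.

```python
import torch
import torch.nn as nn


class MiniExtractor(nn.Module):
    """Hypothetical sketch, not torchextractor's code: records outputs of named submodules."""

    def __init__(self, model, module_names):
        super().__init__()
        self.model = model
        self.feature_maps = {}
        modules = dict(model.named_modules())
        for name in module_names:
            # Bind `name` as a default argument so each hook writes under its own key
            def hook(module, inputs, output, name=name):
                self.feature_maps[name] = output

            modules[name].register_forward_hook(hook)

    def forward(self, x):
        self.feature_maps = {}  # start each call with an empty dict
        output = self.model(x)
        return output, self.feature_maps


# Toy model (assumption): nn.Sequential names its children "0", "1", "2", ...
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
extractor = MiniExtractor(toy, ["0", "2"])
out, feats = extractor(torch.rand(7, 3, 32, 32))
print(out.shape)  # torch.Size([7, 10])
print({name: f.shape for name, f in feats.items()})
# {'0': torch.Size([7, 8, 32, 32]), '2': torch.Size([7, 8, 1, 1])}
```

The hook returns `None` on purpose. A forward hook that returns a value replaces the module's output, which an extractor should never do.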

See more examples on Binder or in Colab.

Read the documentation

FAQ

• How do I know the names of the modules?

You can print all module names like this:

print(tx.list_module_names(model))

# OR

for name, module in model.named_modules():
    print(name)
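On a small nested model, `named_modules()` yields dotted paths. These are the kind of strings you pass to the extractor, such as "layer1" for ResNet. The toy model below is an assumption chosen to mimic that naming. The first entry, the empty string, is the root model itself.

```python
from collections import OrderedDict

import torch.nn as nn

# Toy model (assumption): named children, similar to ResNet's "layer1", "layer2", ...
model = nn.Sequential(OrderedDict([
    ("layer1", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())),
    ("head", nn.Linear(8, 2)),
]))

names = [name for name, _ in model.named_modules()]
print(names)  # ['', 'layer1', 'layer1.0', 'layer1.1', 'head']
```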

• Why do some operations not get listed?

It is not possible to add hooks if operations are not defined as modules. Therefore, F.relu cannot be captured, but nn.ReLU() can.
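A minimal illustration of this limitation in plain PyTorch follows, with no torchextractor involved. The `Net` class is a made-up example. The `nn.ReLU()` attribute is a registered submodule, so it appears in `named_modules()` and can carry a forward hook. The trailing `F.relu` call is just a function, so there is nothing to attach a hook to.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Hypothetical example model mixing a module activation and a functional one."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.act = nn.ReLU()  # a module: listed by named_modules() and hookable

    def forward(self, x):
        x = self.act(self.fc(x))
        return F.relu(x - 1)  # a plain function call: invisible to hooks


net = Net()
captured = {}


def save_act(module, inputs, output):
    # Store the output; returning None leaves the module output unchanged
    captured["act"] = output


net.act.register_forward_hook(save_act)
net(torch.randn(2, 4))

module_names = [name for name, _ in net.named_modules()]
print(module_names)  # ['', 'fc', 'act']  (no entry for F.relu)
print(captured["act"].shape)  # torch.Size([2, 4])
```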

• How can I avoid listing all relevant modules?

You can specify a custom filtering function to hook the relevant modules:

# Hook everything!
module_filter_fn = lambda module, name: True

# Capture all modules inside the first layer
module_filter_fn = lambda module, name: name.startswith("layer1")

# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)

model = tx.Extractor(model, module_filter_fn=module_filter_fn)
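To preview which names a filter would select before hooking anything, you can apply it to `named_modules()` yourself. The helper `select_modules` and the toy model below are hypothetical and only mirror the `(module, name)` call order used in the filters above.

```python
from collections import OrderedDict

import torch.nn as nn

# Toy model (assumption) with ResNet-like layer names
model = nn.Sequential(OrderedDict([
    ("layer1", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())),
    ("layer2", nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU())),
]))


def select_modules(model, module_filter_fn):
    """Hypothetical helper: names of submodules for which the filter returns True."""
    return [name for name, module in model.named_modules() if module_filter_fn(module, name)]


in_layer1 = select_modules(model, lambda module, name: name.startswith("layer1"))
convs = select_modules(model, lambda module, name: isinstance(module, nn.Conv2d))
print(in_layer1)  # ['layer1', 'layer1.0', 'layer1.1']
print(convs)      # ['layer1.0', 'layer2.0']
```

A plain prefix test also matches the container "layer1" itself, and it would match a hypothetical "layer10". Testing `name == "layer1" or name.startswith("layer1.")` is stricter if that matters.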

• Is it compatible with ONNX?

tx.Extractor is compatible with ONNX! This means you can also access intermediate feature maps after the export.

Pro-tip: name the output nodes by using output_names when calling torch.onnx.export.

• Is it compatible with TorchScript?

Not yet, but we are working on it. Support for compiling the registered hooks of a module was only recently added in PyTorch v1.8.0.

• "One more thing!" 😉

By default, we capture the latest output of the relevant modules, but you can specify your own custom operations.

For example, to accumulate features over 10 forward passes, you can do the following:

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)

def capture_fn(module, input, output, module_name, feature_maps):
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(output)

extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)

for _ in range(20):
    for _ in range(10):
        x = torch.rand(7, 3, 224, 224)
        model(x)
    feature_maps = extractor.collect()

    # Do your stuff here

    # Discard collected elements
    extractor.clear_placeholder()
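To see how a `capture_fn` with this signature can plug into a standard PyTorch forward hook, here is a plain-PyTorch sketch. It is an assumption about the wiring, not torchextractor's internals, and the small MLP is a made-up stand-in for ResNet. `functools.partial` binds `module_name` and `feature_maps`, so PyTorch still calls a `(module, input, output)` hook.

```python
from functools import partial

import torch
import torch.nn as nn


def capture_fn(module, input, output, module_name, feature_maps):
    # Same signature as above; append instead of overwriting.
    # .detach() (our choice) avoids keeping autograd graphs alive across passes.
    feature_maps.setdefault(module_name, []).append(output.detach())


# Toy model (assumption); children are named "0", "1", "2"
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
modules = dict(model.named_modules())
feature_maps = {}
for name in ["0", "2"]:
    hook = partial(capture_fn, module_name=name, feature_maps=feature_maps)
    modules[name].register_forward_hook(hook)

for _ in range(10):
    model(torch.rand(7, 4))

counts = {name: len(outputs) for name, outputs in feature_maps.items()}
print(counts)  # {'0': 10, '2': 10}
stacked = torch.stack(feature_maps["2"])  # one entry per forward pass
print(stacked.shape)  # torch.Size([10, 7, 2])

feature_maps.clear()  # analogous to discarding collected elements
```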

Contributing

All feedback and contributions are welcome. Feel free to report an issue or to create a pull request!

If you want to get hands-on:

  1. (Fork and) clone the repo.
  2. Create a virtual environment: virtualenv -p python3 .venv && source .venv/bin/activate
  3. Install dependencies: pip install -r requirements.txt && pip install -r requirements-dev.txt
  4. Hook auto-formatting tools: pre-commit install
  5. Hack as much as you want!
  6. Run tests: python -m unittest discover -vs ./tests/
  7. Share your work and create a pull request.

To build the documentation:

cd docs
pip install -r requirements.txt
make html
Comments
  • Only extracting part of the intermediate feature with DataParallel

    Hi @antoinebrl,

    I am using torch.nn.DataParallel on a 2-GPU machine with a batch size of N. Data parallel training splits the input batch into 2 pieces and sends one to each GPU.

    When I use torchextractor to obtain the intermediate features, the input and output batch sizes are both N as expected, but the feature batch size becomes N/2. Does this mean we only extract the features from one GPU? I'm not sure, because I couldn't find an exact match.

    Can you please explain why this happens? Should the expected behavior be to return features from all GPUs, or from a specified one?

    A minimal example to reproduce:

    import torch
    import torchvision
    import torchextractor as tx
    
    model = torchvision.models.resnet18(pretrained=True)
    model_gpu = torch.nn.DataParallel(torchvision.models.resnet18(pretrained=True))
    model_gpu.cuda()
    
    model = tx.Extractor(model, ["layer1"])
    model_gpu = tx.Extractor(model_gpu, ["module.layer1"])
    dummy_input = torch.rand(8, 3, 224, 224)
    _, features = model(dummy_input)
    _, features_gpu = model_gpu(dummy_input)
    feature_shapes = {name: f.shape for name, f in features.items()}
    print(feature_shapes)
    feature_shapes_gpu = {name: f.shape for name, f in features_gpu.items()}
    print(feature_shapes_gpu)
    
    # {'layer1': torch.Size([8, 64, 56, 56])}
    # {'module.layer1': torch.Size([4, 64, 56, 56])}
    
    opened by wydwww
Releases: v0.3.0