Multimodal Co-Attention Transformer (MCAT) for Survival Prediction in Gigapixel Whole Slide Images

Related tags

Deep LearningMCAT
Overview

Multimodal Co-Attention Transformer (MCAT) for Survival Prediction in Gigapixel Whole Slide Images

[ICCV 2021]

© Mahmood Lab - This code is made available under the GPLv3 License and is available for non-commercial academic purposes.

If you find our work useful in your research or if you use parts of this code please consider citing our paper:

@inproceedings{chen2021multimodal,
  title={Multimodal Co-Attention Transformer for Survival Prediction in Gigapixel Whole Slide Images},
  author={Chen, Richard J and Lu, Ming Y and Weng, Wei-Hung and Chen, Tiffany Y and Williamson, Drew FK and Manz, Trevor and Shady, Maha and Mahmood, Faisal},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={4015--4025},
  year={2021}
}

Updates:

  • 11/12/2021: Several users have raised concerns about the low c-Index for GBMLGG in SNN (Genomic Only). In using the gene families from MSigDB as gene signatures, IDH1 mutation was not included (key biomarker in distinguishing GBM and LGG).
  • 06/18/2021: Updated data preprocessing section for reproducibility.
  • 06/17/2021: Uploaded predicted risk scores on the validation folds for each models, and the evaluation script to compute the c-Index and Integrated AUC (I-AUC) validation metrics, found using the following Jupyter Notebook. Model checkpoints for MCAT are uploaded in the results directory.
  • 06/17/2021: Uploaded notebook detailing the MCAT network architecture, with sample input in the following following Jupyter Notebook, in which we print the shape of the tensors at each stage of MCAT.

Pre-requisites:

  • Linux (Tested on Ubuntu 18.04)
  • NVIDIA GPU (Tested on Nvidia GeForce RTX 2080 Ti x 16) with CUDA 11.0 and cuDNN 7.5
  • Python (3.7.7), h5py (2.10.0), matplotlib (3.1.1), numpy (1.18.1), opencv-python (4.1.1), openslide-python (1.1.1), openslide (3.4.1), pandas (1.1.3), pillow (7.0.0), PyTorch (1.6.0), scikit-learn (0.22.1), scipy (1.4.1), tensorflow (1.13.1), tensorboardx (1.9), torchvision (0.7.0), captum (0.2.0), shap (0.35.0)

Installation Guide for Linux (using anaconda)

1. Downloading TCGA Data

To download diagnostic WSIs (formatted as .svs files), molecular feature data and other clinical metadata, please refer to the NIH Genomic Data Commons Data Portal and the cBioPortal. WSIs for each cancer type can be downloaded using the GDC Data Transfer Tool.

2. Processing Whole Slide Images

To process WSIs, first, the tissue regions in each biopsy slide are segmented using Otsu's Segmentation on a downsampled WSI using OpenSlide. The 256 x 256 patches without spatial overlapping are extracted from the segmented tissue regions at the desired magnification. Consequently, a pretrained truncated ResNet50 is used to encode raw image patches into 1024-dim feature vectors, which we then save as .pt files for each WSI. The extracted features then serve as input (in a .pt file) to the network. The following folder structure is assumed for the extracted features vectors:

DATA_ROOT_DIR/
    └──TCGA_BLCA/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_BRCA/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_GBMLGG/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_LUAD/
        ├── slide_1.ptd
        ├── slide_2.pt
        └── ...
    └──TCGA_UCEC/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    ...

DATA_ROOT_DIR is the base directory of all datasets / cancer type(e.g. the directory to your SSD). Within DATA_ROOT_DIR, each folder contains a list of .pt files for that dataset / cancer type.

3. Molecular Features and Genomic Signatures

Processed molecular profile features containing mutation status, copy number variation, and RNA-Seq abundance can be downloaded from the cBioPortal, which we include as CSV files in the following directory. For ordering gene features into gene embeddings, we used the following categorization of gene families (categorized via common features such as homology or biochemical activity) from MSigDB. Gene sets for homeodomain proteins and translocated cancer genes were not used due to overlap with transcription factors and oncogenes respectively. The curation of "genomic signatures" can be modified to curate genomic embedding that reflect unique biological functions.

4. Training-Validation Splits

For evaluating the algorithm's performance, we randomly partitioned each dataset using 5-fold cross-validation. Splits for each cancer type are found in the splits/5foldcv folder, which each contain splits_{k}.csv for k = 1 to 5. In each splits_{k}.csv, the first column corresponds to the TCGA Case IDs used for training, and the second column corresponds to the TCGA Case IDs used for validation. Alternatively, one could define their own splits, however, the files would need to be defined in this format. The dataset loader for using these train-val splits are defined in the get_split_from_df function in the Generic_WSI_Survival_Dataset class (inherited from the PyTorch Dataset class).

5. Running Experiments

To run experiments using the SNN, AMIL, and MMF networks defined in this repository, experiments can be run using the following generic command-line:

CUDA_VISIBLE_DEVICES=<DEVICE ID> python main.py --which_splits <SPLIT FOLDER PATH> --split_dir <SPLITS FOR CANCER TYPE> --mode <WHICH MODALITY> --model_type <WHICH MODEL>

Commands for all experiments / models can be found in the Commands.md file.

Owner
Mahmood Lab @ Harvard/BWH
AI for Pathology Image Analysis Lab @ HMS / BWH
Mahmood Lab @ Harvard/BWH
Code for 1st place solution in Sleep AI Challenge SNU Hospital

Sleep AI Challenge SNU Hospital 2021 Code for 1st place solution for Sleep AI Challenge (Note that the code is not fully organized) Refer to the notio

Saewon Yang 13 Jan 03, 2022
FewBit — a library for memory efficient training of large neural networks

FewBit FewBit — a library for memory efficient training of large neural networks. Its efficiency originates from storage optimizations applied to back

24 Oct 22, 2022
Implementation of "Learning Multi-Granular Hypergraphs for Video-Based Person Re-Identification"

hypergraph_reid Implementation of "Learning Multi-Granular Hypergraphs for Video-Based Person Re-Identification" If you find this help your research,

62 Dec 21, 2022
🔥 Cogitare - A Modern, Fast, and Modular Deep Learning and Machine Learning framework for Python

Cogitare is a Modern, Fast, and Modular Deep Learning and Machine Learning framework for Python. A friendly interface for beginners and a powerful too

Cogitare - Modern and Easy Deep Learning with Python 76 Sep 30, 2022
Compact Bilinear Pooling for PyTorch

Compact Bilinear Pooling for PyTorch. This repository has a pure Python implementation of Compact Bilinear Pooling and Count Sketch for PyTorch. This

Grégoire Payen de La Garanderie 234 Dec 07, 2022
Code for "Learning to Segment Rigid Motions from Two Frames".

rigidmask Code for "Learning to Segment Rigid Motions from Two Frames". ** This is a partial release with inference and evaluation code.

Gengshan Yang 157 Nov 21, 2022
RCT-ART is an NLP pipeline built with spaCy for converting clinical trial result sentences into tables through jointly extracting intervention, outcome and outcome measure entities and their relations.

Randomised controlled trial abstract result tabulator RCT-ART is an NLP pipeline built with spaCy for converting clinical trial result sentences into

2 Sep 16, 2022
This is a Pytorch implementation of paper: DropEdge: Towards Deep Graph Convolutional Networks on Node Classification

DropEdge: Towards Deep Graph Convolutional Networks on Node Classification This is a Pytorch implementation of paper: DropEdge: Towards Deep Graph Con

401 Dec 16, 2022
PyTorch implementation of Soft-DTW: a Differentiable Loss Function for Time-Series in CUDA

Soft DTW Loss Function for PyTorch in CUDA This is a Pytorch Implementation of Soft-DTW: a Differentiable Loss Function for Time-Series which is batch

Keon Lee 76 Dec 20, 2022
IndoNLI: A Natural Language Inference Dataset for Indonesian

IndoNLI: A Natural Language Inference Dataset for Indonesian This is a repository for data and code accompanying our EMNLP 2021 paper "IndoNLI: A Natu

15 Feb 10, 2022
Topic Discovery via Latent Space Clustering of Pretrained Language Model Representations

TopClus The source code used for Topic Discovery via Latent Space Clustering of Pretrained Language Model Representations, published in WWW 2022. Requ

Yu Meng 63 Dec 18, 2022
Implementation of Change-Based Exploration Transfer (C-BET)

Implementation of Change-Based Exploration Transfer (C-BET), as presented in Interesting Object, Curious Agent: Learning Task-Agnostic Exploration.

Simone Parisi 29 Dec 04, 2022
A torch.Tensor-like DataFrame library supporting multiple execution runtimes and Arrow as a common memory format

TorchArrow (Warning: Unstable Prototype) This is a prototype library currently under heavy development. It does not currently have stable releases, an

Facebook Research 536 Jan 06, 2023
MaskTrackRCNN for video instance segmentation based on mmdetection

MaskTrackRCNN for video instance segmentation Introduction This repo serves as the official code release of the MaskTrackRCNN model for video instance

411 Jan 05, 2023
This repository is for EMNLP 2021 paper: It is Not as Good as You Think! Evaluating Simultaneous Machine Translation on Interpretation Data

InterpretationData This repository is for our EMNLP 2021 paper: It is Not as Good as You Think! Evaluating Simultaneous Machine Translation on Interpr

4 Apr 21, 2022
In the case of your data having only 1 channel while want to use timm models

timm_custom Description In the case of your data having only 1 channel while want to use timm models (with or without pretrained weights), run the fol

2 Nov 26, 2021
Segmentation models with pretrained backbones. Keras and TensorFlow Keras.

Python library with Neural Networks for Image Segmentation based on Keras and TensorFlow. The main features of this library are: High level API (just

Pavel Yakubovskiy 4.2k Jan 09, 2023
Python code for the paper How to scale hyperparameters for quickshift image segmentation

How to scale hyperparameters for quickshift image segmentation Python code for the paper How to scale hyperparameters for quickshift image segmentatio

0 Jan 25, 2022
Lucid Sonic Dreams syncs GAN-generated visuals to music.

Lucid Sonic Dreams Lucid Sonic Dreams syncs GAN-generated visuals to music. By default, it uses NVLabs StyleGAN2, with pre-trained models lifted from

731 Jan 02, 2023
Code used for the results in the paper "ClassMix: Segmentation-Based Data Augmentation for Semi-Supervised Learning"

Code used for the results in the paper "ClassMix: Segmentation-Based Data Augmentation for Semi-Supervised Learning" Getting started Prerequisites CUD

70 Dec 02, 2022