PyTorch trainer and model for Sequence Classification

Overview

After cloning the repository, prepare your training data as a .csv file with two columns: Text and Label.

In the example below, we assume the training data has 3 labels and is stored in a file named train_data.csv.
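A minimal sketch of a file in the expected layout, assuming the labels are integer class ids 0, 1 and 2 (the README does not say whether Label should be an integer or a string, so check how utils.py reads it). The sentences are placeholders, the file is written as train_data_example.csv so it cannot clobber a real train_data.csv, and check_training_csv is a hypothetical helper, not part of the repository.

```python
import pandas as pd

# Placeholder rows: exactly two columns, Text and Label, with 3 label values.
# Assumption: labels are integer class ids 0..2; adjust if utils.py expects strings.
example = pd.DataFrame({
    "Text": [
        "The service was excellent.",
        "It was fine, nothing special.",
        "I will never come back.",
    ],
    "Label": [0, 1, 2],
})
example.to_csv("train_data_example.csv", index=False)  # index=False keeps only the 2 columns


def check_training_csv(path, num_labels):
    """Hypothetical sanity check: required columns, no blanks, labels in range."""
    data = pd.read_csv(path)
    missing = {"Text", "Label"} - set(data.columns)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    if data[["Text", "Label"]].isna().any().any():
        raise ValueError("found empty Text or Label cells")
    if not data["Label"].between(0, num_labels - 1).all():
        raise ValueError(f"labels must be in [0, {num_labels - 1}]")
    return data


checked = check_training_csv("train_data_example.csv", num_labels=3)
print(checked.shape)  # (3, 2)
```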

Example Usage

Import dependencies

import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoConfig

from EarlyStopping import *
from modelling import *
from utils import *

Specify arguments

args.pretrained_path is the path to the pretrained language model (or its Hugging Face Hub name, such as bert-base-uncased).

class args:
    fold = 0
    pretrained_path = 'bert-base-uncased'
    max_length = 400
    train_batch_size = 16
    val_batch_size = 64
    epochs = 5
    learning_rate = 1e-5
    accumulation_steps = 2
    num_splits = 5

Create train and validation data

In this example, we train the model with cross-validation, splitting the training data into args.num_splits folds.

df = pd.read_csv('./train_data.csv')
df = create_k_folds(df, args.num_splits)

# training rows: every fold except the held-out one
df_train = df[df['kfold'] != args.fold].reset_index(drop = True)
# validation rows: only the held-out fold
df_valid = df[df['kfold'] == args.fold].reset_index(drop = True)
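The README does not show how create_k_folds assigns folds. The sketch below is a hypothetical stand-in (create_k_folds_sketch, not the repository's function) that shuffles the rows and deals each label's rows round-robin across folds, a simple form of stratified k-fold, then selects the training and validation rows the same way as above. The 30-row DataFrame is placeholder data.

```python
import numpy as np
import pandas as pd


def create_k_folds_sketch(df, num_splits, label_col="Label", seed=42):
    """Hypothetical stand-in for utils.create_k_folds: adds a 'kfold' column.

    Rows are shuffled, then dealt round-robin within each label so every
    fold gets a similar label distribution (a simple stratified split).
    """
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    df["kfold"] = -1
    for _, idx in df.groupby(label_col).groups.items():
        # idx holds the row labels of one class; deal them out 0, 1, ..., k-1, 0, ...
        df.loc[idx, "kfold"] = np.arange(len(idx)) % num_splits
    return df


# Placeholder data: 30 rows, 10 per label.
df = pd.DataFrame({
    "Text": [f"example {i}" for i in range(30)],
    "Label": [i % 3 for i in range(30)],
})
df = create_k_folds_sketch(df, num_splits=5)

fold = 0
df_train = df[df["kfold"] != fold].reset_index(drop=True)  # 4 folds for training
df_valid = df[df["kfold"] == fold].reset_index(drop=True)  # 1 held-out fold
print(len(df_train), len(df_valid))  # 24 6
```

Because the two filters use != and ==, every row lands in exactly one of the two frames, so no validation example leaks into training.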

Load the language model and its tokenizer

config = AutoConfig.from_pretrained(args.pretrained_path)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_path)
model_transformer = AutoModel.from_pretrained(args.pretrained_path)

Prepare train and validation dataloaders

features = []
for i in range(len(df_train)):
    features.append(prepare_features(tokenizer, df_train.iloc[i, :].to_dict(), args.max_length))
    
train_dataset = CreateDataset(features)
train_dataloader = create_dataloader(train_dataset, args.train_batch_size, 'train')

features = []
for i in range(len(df_valid)):
    features.append(prepare_features(tokenizer, df_valid.iloc[i, :].to_dict(), args.max_length))
    
val_dataset = CreateDataset(features)
val_dataloader = create_dataloader(val_dataset, args.val_batch_size, 'val')
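The README does not document what prepare_features returns for each row. One plausible shape, sketched below as the hypothetical prepare_features_sketch, calls a Hugging Face tokenizer with truncation and fixed-length padding and keeps the label next to the token ids; the output key names are assumptions, so compare them with utils.py. toy_tokenizer is a made-up offline stand-in so the sketch runs without downloading bert-base-uncased; its ids are fake.

```python
def prepare_features_sketch(tokenizer, row, max_length):
    """Hypothetical equivalent of utils.prepare_features for one CSV row.

    `row` is a dict such as {'Text': '...', 'Label': 2}. Long texts are
    truncated and short ones padded so every example has `max_length` ids.
    """
    encoded = tokenizer(
        row["Text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],  # 1 = real token, 0 = padding
        "label": int(row["Label"]),
    }


def toy_tokenizer(text, truncation, padding, max_length):
    """Offline stand-in for a Hugging Face tokenizer (whitespace split, fake ids)."""
    ids = [len(word) for word in text.split()]  # fake token ids, illustration only
    if truncation:
        ids = ids[:max_length]
    mask = [1] * len(ids)
    if padding == "max_length":
        pad = max_length - len(ids)
        ids, mask = ids + [0] * pad, mask + [0] * pad
    return {"input_ids": ids, "attention_mask": mask}


features = prepare_features_sketch(toy_tokenizer, {"Text": "a tiny example", "Label": 1}, max_length=5)
print(features)
# {'input_ids': [1, 4, 7, 0, 0], 'attention_mask': [1, 1, 1, 0, 0], 'label': 1}
```

With the real tokenizer, the loops above can also be written as a list comprehension over df_train.to_dict("records"), which yields the same per-row dicts as df_train.iloc[i, :].to_dict().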

Use EarlyStopping and customize the score function

NOTE: The custom score function must take two parameters: the logits and the ground-truth labels.

def accuracy(logits, labels):
    logits = logits.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    # argmax over the raw logits is the predicted class; dividing by the row sum
    # is unnecessary and flips the argmax when the logits sum to a negative number
    pred_classes = np.argmax(logits, axis = -1)
    pred_classes = pred_classes.reshape(labels.shape)
    
    return np.sum(pred_classes == labels) / labels.shape[0]

es = EarlyStopping(mode = 'max', patience = 3, monitor = 'val_acc', out_path = 'model.bin')
es.monitor_score_function = accuracy
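Any function with the same (logits, labels) signature can be used instead. As an illustration, a hedged sketch of a macro-averaged F1 score in plain NumPy; macro_f1 and _to_numpy are hypothetical names, and the conversion helper accepts either torch tensors (as accuracy above expects) or NumPy arrays. Classes absent from both predictions and labels are skipped, similar to scikit-learn's default.

```python
import numpy as np


def _to_numpy(x):
    """Accept torch tensors (as passed during training) or plain arrays."""
    if hasattr(x, "detach"):
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def macro_f1(logits, labels):
    """Custom score with the same (logits, labels) signature as accuracy."""
    logits = _to_numpy(logits)
    labels = _to_numpy(labels).reshape(-1)
    preds = logits.argmax(axis=-1).reshape(-1)
    f1_scores = []
    for cls in range(logits.shape[-1]):
        tp = np.sum((preds == cls) & (labels == cls))
        fp = np.sum((preds == cls) & (labels != cls))
        fn = np.sum((preds != cls) & (labels == cls))
        denom = 2 * tp + fp + fn
        if denom == 0:
            continue  # class never predicted and never present: skip it
        f1_scores.append(2 * tp / denom)
    return float(np.mean(f1_scores)) if f1_scores else 0.0
```

It would be plugged in the same way as accuracy, by assigning es.monitor_score_function = macro_f1 (and, for readability, a matching monitor name).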

Create and train the model

Calling the fit method starts the training process:

model = Model(config, model_transformer, num_labels = 3)
model.to('cuda')
num_train_steps = int(len(train_dataset) / args.train_batch_size * args.epochs)
model.fit(args.epochs, args.learning_rate, num_train_steps, args.accumulation_steps, 
          train_dataloader, val_dataloader, es)
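num_train_steps is presumably used to size a learning-rate schedule. The formula above counts (roughly) batches, while with accumulation_steps = 2 the optimizer only updates about half as often. Which count Model.fit expects is not documented here, so check modelling.py. The hypothetical count_steps helper below makes the arithmetic explicit, assuming the last partial batch is kept and a leftover accumulation group still triggers an update; the 1,000-row dataset size is only an example.

```python
import math


def count_steps(num_samples, batch_size, epochs, accumulation_steps):
    """Hypothetical helper: (batches seen, optimizer updates) over all epochs."""
    batches_per_epoch = math.ceil(num_samples / batch_size)  # last batch may be partial
    updates_per_epoch = math.ceil(batches_per_epoch / accumulation_steps)
    return batches_per_epoch * epochs, updates_per_epoch * epochs


# With train_batch_size = 16, epochs = 5, accumulation_steps = 2 and 1,000 rows:
print(count_steps(1000, 16, 5, 2))  # (315, 160)
print(int(1000 / 16 * 5))           # 312, the README's formula
```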

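NOTE: To complete the cross-validation training process, run the code above again with args.fold set to 1, 2, ..., args.num_splits - 1. Because out_path is fixed to 'model.bin', consider a per-fold name (for example f'model_{args.fold}.bin') so later folds do not overwrite earlier checkpoints.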
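Rather than editing args.fold by hand, the folds can be driven from a loop. The sketch below is hypothetical: run_cross_validation and train_one_fold are illustrative names, and train_one_fold is assumed to wrap the steps above for one fold (split, dataloaders, an EarlyStopping with a per-fold out_path, model.fit) and return that fold's best validation score, however you read it back from your EarlyStopping object. The stub scores are placeholders, not real results.

```python
import statistics


def run_cross_validation(train_one_fold, num_splits):
    """Hypothetical driver: train every fold and summarise the validation scores."""
    scores = []
    for fold in range(num_splits):
        scores.append(train_one_fold(fold))
    return statistics.mean(scores), statistics.pstdev(scores), scores


# Stand-in trainer returning placeholder scores; replace with real per-fold training.
def fake_train_one_fold(fold):
    return [0.80, 0.90, 0.70, 0.85, 0.75][fold]


mean_score, std_score, fold_scores = run_cross_validation(fake_train_one_fold, num_splits=5)
print(f"{mean_score:.3f} +/- {std_score:.3f}")  # 0.800 +/- 0.071
```

Reporting the mean and spread across folds gives a more stable estimate than the score of any single fold.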

Owner: NhanTieu