SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in, assuming they are atoms.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates

Edges

To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge types, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)

Caching

By default, the basis vectors are cached. If you ever need to clear the cache, set the environment variable CLEAR_CACHE to any value when launching the script:

$ CLEAR_CACHE=1 python train.py

Alternatively, delete the cache directory itself, which should exist at the path below:

$ rm -rf ~/.cache.equivariant_attention
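
The two options above can also be combined in a small helper. The following is only a sketch of that logic, not the library's own code: the function name `clear_basis_cache` and its `env` parameter are hypothetical, and the default path is simply the one quoted above.

```python
import os
import shutil
from pathlib import Path

# Cache location quoted in the text above; treated here as a default only.
DEFAULT_CACHE_DIR = Path.home() / ".cache.equivariant_attention"


def clear_basis_cache(cache_dir=DEFAULT_CACHE_DIR, env=os.environ):
    """Delete cache_dir when CLEAR_CACHE is set; return True if it was removed.

    Hypothetical helper: the name and signature are assumptions for illustration.
    """
    if env.get("CLEAR_CACHE") is None:
        return False              # flag not set: leave the cache alone
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return False              # nothing to delete
    shutil.rmtree(cache_dir)      # same effect as `rm -rf` on the directory
    return True
```

Passing `env` explicitly makes the helper easy to try against a temporary directory before pointing it at the real cache.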

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Breaking equivariance

    Hi, Thanks a lot for your work!!

    I was running some equivariance tests on your implementation of the SE3-Transformer and found that equivariance is not always preserved. It does not break every time, and unfortunately I do not know where the bug is.

    I have attached an image with an example of equivariance not being preserved.

    [image]

    To generate this image I used the following code:

        import numpy as np
        import matplotlib.pyplot as plt
        import torch
        
        from se3_transformer_pytorch import SE3Transformer
    

    Make some data

        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:].float()
        feat = torch.randint(0, 20, (1, geom.shape[1],1)).float()
        
        def rot_matrix(x):
            # Rotation matrix
            a ,b ,c = 2*np.pi*x
            return np.array([
                [np.cos(a)*np.cos(b), np.cos(a)*np.sin(b)*np.sin(c)- np.sin(a)*np.cos(c), np.cos(a)*np.sin(b)*np.cos(c)+ np.sin(a)*np.sin(c)],
                [np.sin(a)*np.cos(b), np.sin(a)*np.sin(b)*np.sin(c)+ np.cos(a)*np.cos(c), np.sin(a)*np.sin(b)*np.cos(c)- np.cos(a)*np.sin(c)],
                [-np.sin(b)         , np.cos(b)*np.sin(c)                               , np.cos(b)*np.cos(c)                               ]
            ])
    

    Initialize model

        mdl = SE3Transformer(
            dim = 1,
            depth = 3,
            input_degrees = 1,
            num_degrees = 2,
            output_degrees = 2,
            reduce_dim_out = True,
        )
        
        def model(geom,feat):
            return geom + mdl(feat,geom, return_type = 1)
    

    Check rotation equivariance:

        with torch.no_grad():
            
            Q = torch.tensor(rot_matrix(np.random.random(3))).float()
            prerotated = model(geom @ Q, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) @ Q).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Rotated", "Post-Rotated"])
            plt.show()
    

    Check translation equivariance:

        with torch.no_grad():
            
            x0 = 1*torch.rand(3)
            prerotated = model(geom + x0, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) + x0).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Translated", "Post-Translated"])
            plt.show()
    

    Hope this helps!!
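
A numeric check can complement the plots, since a single number is easier to track across runs than overlaid curves. The sketch below uses NumPy only and follows the same row-vector convention (`points @ Q`) as the code above. The functions `toy_equivariant` and `toy_broken` are hypothetical stand-ins for the model, not part of the library. The rotation is drawn via a QR decomposition rather than the Euler-angle formula above.

```python
import numpy as np


def random_rotation(rng):
    """Draw a proper rotation matrix (orthogonal, det = +1) via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))      # fix the column signs of the factorization
    if np.linalg.det(q) < 0:         # turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q


def equivariance_error(fn, points, rotation, shift):
    """Max abs gap between transform-then-apply and apply-then-transform."""
    transformed_first = fn(points @ rotation + shift)
    applied_first = fn(points) @ rotation + shift
    return float(np.abs(transformed_first - applied_first).max())


def toy_equivariant(points):
    # Hypothetical stand-in: pulls every point halfway towards the centroid.
    return points + 0.5 * (points.mean(axis=0, keepdims=True) - points)


def toy_broken(points):
    # Adds a fixed direction, which breaks rotation equivariance.
    return points + np.array([1.0, 0.0, 0.0])


rng = np.random.default_rng(0)
points = rng.normal(size=(40, 3))
Q = random_rotation(rng)
t = rng.normal(size=3)

print(equivariance_error(toy_equivariant, points, Q, t))  # round-off level
print(equivariance_error(toy_broken, points, Q, t))       # clearly non-zero
```

Replacing the toy functions with a wrapper around the real model (converting to and from tensors) would give one error value per random rotation. Repeating it over many draws shows how often the failure occurs.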

    opened by brennanaba 18
  • How to populate input variable length data

    Hello, thank you for your work. I am using your SE3-Transformer implementation as part of my project, but I have run into problems feeding variable-length data into the model. I do not know how to pad the features, coordinates and mask to meet the model's requirements.

    My Dataset processes the input as follows: [image] I pad the coordinates and features with zeros so that their shapes are (246, 3) and (246, 20) respectively, and fill the mask with boolean True and False values. In this case there are 49 valid points, so the valid features and coordinates have shapes (49, 20) and (49, 3), and the rest is zero padding. The first 49 entries of the mask are True and the rest are False.

    I then checked the batches coming out of torch.utils.data.DataLoader, and the input dimensions are fine: [image] My network looks like this, a typical binary classifier, and I print the output of SE3:

    [image]

    However, the SE3 output contains NaN, as shown in the figure below. [image] The first 48 entries have values and the rest are NaN; it is not clear what the problem is.

    The next batch may be entirely NaN because of the weight update: [image] The SE3 output causes all my losses to be NaN.

    I think these NaNs are probably caused by my incorrect padding. I also tried padding with 1 instead, but that did not help either. Maybe the mask I pass in cannot contain False; I am confused about that. Could you please tell me how to pad coordinates, features and masks correctly, or give me some advice on using your SE3 model with variable-length data? Thank you.
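
The padding scheme described above can be written down as a small helper. This is a NumPy sketch under stated assumptions, not the library's own recipe: `pad_batch` and `masked_mean` are hypothetical names, and the premise that padded positions may hold NaN in the model output is taken from the report above. The point of `masked_mean` is that padded slots are replaced before pooling, so a NaN there cannot reach the loss.

```python
import numpy as np


def pad_batch(samples, feat_dim):
    """Pad (feats, coors) pairs to a common length and build the boolean mask.

    feats: (n_i, feat_dim), coors: (n_i, 3); n_i may differ between samples.
    Returns feats (b, n_max, feat_dim), coors (b, n_max, 3), mask (b, n_max).
    """
    n_max = max(len(coors) for _, coors in samples)
    batch = len(samples)
    feats_out = np.zeros((batch, n_max, feat_dim), dtype=np.float32)
    coors_out = np.zeros((batch, n_max, 3), dtype=np.float32)
    mask = np.zeros((batch, n_max), dtype=bool)
    for i, (feats, coors) in enumerate(samples):
        n = len(coors)
        feats_out[i, :n] = feats
        coors_out[i, :n] = coors
        mask[i, :n] = True        # True marks real points, False marks padding
    return feats_out, coors_out, mask


def masked_mean(values, mask):
    """Average (b, n, d) values over real points only.

    np.where replaces padded slots first, because NaN * 0 would still be NaN.
    """
    cleaned = np.where(mask[..., None], values, 0.0)
    return cleaned.sum(axis=1) / mask.sum(axis=1, keepdims=True)
```

Padding only to the longest sample in each batch, rather than to a fixed 246, also keeps the padded region as small as possible.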

    opened by zyk19981118 8
  • CPU/CUDA masking error

    Hi - nice work - I was just testing out your code and ran into the following error, but only in the backward pass:

    RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mask' in call to th_masked_scatter_bool

    When using the nightly Pytorch, the error message is:

    RuntimeError: Tensor for argument #2 'mask' is on CPU, but expected it to be on GPU (while checking arguments for masked_scatter_)

    I'm pretty sure I don't have any tensors in CPU memory, but not sure if this is a bug in your SE3 code or a Pytorch issue. My gut feeling is this is a Pytorch/autograd issue, but I just don't know these particular Pytorch ops well enough to be sure. Tried both 1.7.1 release Pytorch and the latest nightly. Seems like there has been recent work on masked_scatter according to Pytorch issues.

    opened by denjots 8
  • small bug

    https://github.com/lucidrains/se3-transformer-pytorch/blob/7c79998e4d84ec6bd6b6d4b916c6bf30b870b75b/se3_transformer_pytorch/se3_transformer_pytorch.py#L301

    should be if isinstance(m,nn.Linear)

    nbd, but thought you might wanna know.

    opened by MattMcPartlon 5
  • INTERNAL ASSERT FAILED

    Hi - I'm not sure if this is actually a bug or if I'm just expecting too much, but the following code snippet bombs with a PyTorch internal assert failure:

    RuntimeError: sub_iter.strides(0)[0] == 0INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1618643394934/work/aten/src/ATen/native/cuda/Reduce.cuh":929, please report a bug to PyTorch.

    That's using the latest PyTorch nightly. Tried with both CUDA 10.2 and 11.1.

    Looking at the PyTorch bug reports, this seems to suggest that there is massive internal memory allocation being triggered when summing over a large tensor. It seems to have been sitting open for a year and I'm not sure if they even see it as an actual bug there. Certainly it doesn't seem to be a priority.

    The problem is that this is not a particularly large input or a large model (to say the least!), and my GPU has 40 GB of RAM, so the scalability of this transformer model seems very limited as it currently stands. Changing the dim from 128 to 64 does at least allow it to run.

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(dim = 128, heads = 1, depth = 1, dim_head = 1, num_degrees = 1, input_degrees=1, output_degrees=1).cuda()
    
    feats = torch.randn(1, 200, 128).cuda()
    coords = torch.randn(1, 200, 3).cuda()
    
    out = model(feats, coords)
    
    
    opened by denjots 4
  • question about non scalar output

    Hello,

    I am interested in using your implementation on my dataset. There are two things I want to check with you

    1. The property I want to predict is a 3x3 symmetric PSD matrix.
    2. There are some edge features (one categorical feature, one continuous feature) besides the coordinate difference

    I was wondering whether the current se3-transformer can work with such a scenario. Thanks!
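
One generic way to obtain a symmetric PSD 3x3 matrix from vector-valued outputs is to sum outer products. The sketch below is not a feature of the library. It assumes that one 3-vector per point is available (the shape that `return_type = 1` yields in the README examples). The name `psd_from_vectors` is hypothetical.

```python
import numpy as np


def psd_from_vectors(vectors):
    """Sum of outer products v v^T over the rows of an (n, 3) array.

    The result is symmetric and positive semi-definite by construction.
    """
    return vectors.T @ vectors


rng = np.random.default_rng(0)
V = rng.normal(size=(32, 3))          # stand-in for per-point type 1 outputs
M = psd_from_vectors(V)

# Rotation about the z axis, applied to row vectors as V @ Q.
angle = 0.7
Q = np.array([
    [np.cos(angle), -np.sin(angle), 0.0],
    [np.sin(angle),  np.cos(angle), 0.0],
    [0.0,            0.0,           1.0],
])

# Rotating the vectors rotates the matrix: (V Q)^T (V Q) = Q^T M Q.
print(np.allclose(psd_from_vectors(V @ Q), Q.T @ M @ Q))
```

If the vectors rotate with the input, the matrix built this way transforms consistently with them, which is usually the behaviour wanted for a tensor-valued property.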

    opened by Chen-Cai-OSU 3
  • Reversible flag odd results

    Sorry if it's expected behaviour again, but with a slight tweak of your new example code I get unexpected results when your new reversible option is used to return type 1 data. Here's the code I'm running...

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(
        num_tokens = 20,
        dim = 32,
        dim_head = 32,
        heads = 4,
        depth = 12,             # 12 layers
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True,
        reversible = True       # set reversible to True
    ).cuda()
    
    atoms = torch.randint(0, 4, (1, 50)).cuda()
    coors = torch.randn(1, 50, 3).cuda()
    mask  = torch.ones(1, 50).bool().cuda()
    
    pred = model(atoms, coors, mask = mask, return_type = 1)
    
    loss = pred.mean()
    print(loss)
    loss.backward()
    
    

    Without the reversible flag, the loss is close to zero as might be expected as the input coords are centered on the origin. However, when reversible is set, a very high value is produced which is going to blow up training if this was for real. Is this an expected side effect of reversible nets? Is some kind of normalization essential in that case? Or am I just doing something wrong here?

    opened by denjots 3
  • faster loop

    My small grain of sand for this project ;) : at least avoid Python appends, which are 2x slower than list comprehensions.

    If this is of any help, here are some considerations (there might be misunderstandings on my side due to the decorators and so on):

    • I have my reservations about the usefulness of line 62 in utils.py
    • wouldn't it make more sense to run the modified for loop over i (line 122 of utils.py) in reverse order, then use the cached calculations for lpmv()?
    • same case (loop in reverse order) for the for loop at line 148 of basis.py?
    • if scipy.special.poch (which can deal with np arrays) is used instead of the custom pochhammer implementation, all operations inside get_spherical_harmonics_element are vectorizable except the lpmv function call.
      • My sense is that lpmv, get_spherical_harmonics_element and get_spherical_harmonics could all be wrapped in a single function (at the cost of lower reusability / extensibility... so maybe the reversed loop order plus caching is enough).
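
To make the append-versus-comprehension point concrete, here is a standard-library sketch. The `pochhammer` function below is a plain rising-factorial stand-in, not the repository's own implementation and not `scipy.special.poch`. The two table builders are hypothetical names.

```python
from math import prod


def pochhammer(x, k):
    """Rising factorial x * (x + 1) * ... * (x + k - 1); equals 1 for k == 0."""
    return prod(x + i for i in range(k))


def table_with_append(xs, k):
    out = []
    for x in xs:                  # explicit loop with repeated list.append calls
        out.append(pochhammer(x, k))
    return out


def table_with_comprehension(xs, k):
    return [pochhammer(x, k) for x in xs]   # same values, built in one expression


print(pochhammer(3, 4))                        # 3 * 4 * 5 * 6 = 360
print(table_with_comprehension([1, 2, 3], 3))  # [6, 24, 60]
```

Both builders return identical lists, so swapping one for the other is behaviour-preserving. The actual speed difference is best measured with `timeit` on the loops in question.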
    opened by hypnopump 1
  • Question about continuous edge features

    Thanks for all of the work!

    I am working on a simple proof of concept with your model. Ideally, I would like to perform multidimensional scaling, i.e. given a distance matrix, recover the corresponding coordinates.

    I was wondering if distance information could be passed as an edge feature (continuous rather than categorical information). Is it possible to do this with the current implementation?
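
For reference, the task described above has a closed-form baseline, classical multidimensional scaling, against which a learned model could be compared. The NumPy sketch below implements that baseline only. It does not use the SE3 Transformer and does not show whether the library accepts continuous edge features. The function names are hypothetical.

```python
import numpy as np


def pairwise_distances(coords):
    """Dense (n, n) Euclidean distance matrix for an (n, 3) coordinate array."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)


def classical_mds(dist, dim=3):
    """Recover coordinates from distances, up to rotation, reflection and shift."""
    n = dist.shape[0]
    centering = np.eye(n) - np.full((n, n), 1.0 / n)
    gram = -0.5 * centering @ (dist ** 2) @ centering   # double-centred Gram matrix
    eigvals, eigvecs = np.linalg.eigh(gram)              # eigenvalues in ascending order
    top = np.argsort(eigvals)[::-1][:dim]                # indices of the largest `dim`
    scale = np.sqrt(np.clip(eigvals[top], 0.0, None))    # guard against tiny negatives
    return eigvecs[:, top] * scale


rng = np.random.default_rng(0)
coords = rng.normal(size=(32, 3))
dist = pairwise_distances(coords)
recovered = classical_mds(dist)

# The recovered points reproduce the input distances.
print(np.allclose(pairwise_distances(recovered), dist, atol=1e-8))
```

The recovered coordinates match the originals only up to a rigid motion or reflection, so comparisons should be made on distances or after alignment.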

    Thanks again and I appreciate the help!

    opened by MattMcPartlon 1
  • denoise.py bugfix

    Fixes issue related to the constructor

    Traceback (most recent call last):
      File "/home/jcastellanos/projects/se3-transformer-pytorch/denoise.py", line 22, in <module>
        transformer = SE3Transformer(
      File "/home/jcastellanos/projects/se3-transformer-pytorch/se3_transformer_pytorch/se3_transformer_pytorch.py", line 1072, in __init__
        self.num_degrees = num_degrees if exists(num_degrees) else (max(hidden_fiber_dict.keys()) + 1)
    AttributeError: 'NoneType' object has no attribute 'keys'
    
    opened by javierbq 0
  • CUDA out of memory

    Thanks for your great work!

    The se3-transformer is powerful, but it seems to be memory intensive.

    I built a model with the following parameters and got a "CUDA out of memory" error when I ran it on the GPU (Nvidia V100 / 32G).

    import torch
    from se3_transformer_pytorch import SE3Transformer

    model = SE3Transformer(
        dim = 20, heads = 4, depth = 2,
        dim_head = 5, num_degrees = 2, valid_radius = 5
    )

    num_points = 512
    feats = torch.randn(1, num_points, 20)
    coors = torch.randn(1, num_points, 3)
    mask = torch.ones(1, num_points).bool()


    Is this error related to the PyTorch version, and how can I fix it?
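
One thing that can be checked independently of the library is how many point pairs actually fall inside the chosen radius. The sketch below draws coordinates from a standard normal distribution, as `torch.randn` does, and counts pairs with NumPy. The helper name `fraction_within_radius` is hypothetical. Whether the number of pairs is what drives the memory use here is an assumption, not something the report establishes.

```python
import numpy as np


def fraction_within_radius(coords, radius):
    """Share of distinct point pairs whose distance is below `radius`."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    i, j = np.triu_indices(len(coords), k=1)   # count each unordered pair once
    return float((dist[i, j] < radius).mean())


rng = np.random.default_rng(0)
coords = rng.normal(size=(512, 3))    # same distribution as torch.randn(512, 3)

print(fraction_within_radius(coords, 5.0))   # nearly every pair is inside
print(fraction_within_radius(coords, 1.0))   # only a small share is inside
```

For unit-variance coordinates, a radius of 5 excludes almost nothing, so it cannot be expected to shrink the neighbourhoods much. Coordinates on a realistic length scale may behave very differently.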

    opened by PengCheng-NUDT 0
  • SE3Transformer constructor hangs

    I am trying to run an example from the README. The code is:

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    print('Initialising model...')
    model = SE3Transformer(
        dim = 512,
        heads = 8,
        depth = 6,
        dim_head = 64,
        num_degrees = 4,
        valid_radius = 10
    )
    
    print('Running model...')
    feats = torch.randn(1, 1024, 512)
    coors = torch.randn(1, 1024, 3)
    mask  = torch.ones(1, 1024).bool()
    
    out = model(feats, coors, mask) # (1, 1024, 512)
    

    The output hangs on 'Initialising model...' and eventually the kernel dies.

    Any ideas why this would be happening?

    Here is my pip freeze:

    anyio==3.2.1
    argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613036642480/work
    astunparse==1.6.3
    async-generator==1.10
    attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
    axial-positional-embedding==0.2.1
    Babel==2.9.1
    backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
    biopython==1.79
    bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
    cached-property @ file:///tmp/build/80754af9/cached-property_1600785575025/work
    certifi==2021.5.30
    cffi @ file:///tmp/build/80754af9/cffi_1613246939562/work
    chardet==4.0.0
    click==8.0.1
    configparser==5.0.2
    decorator==4.4.2
    defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
    dgl-cu101==0.4.3.post2
    dgl-cu110==0.6.1
    docker-pycreds==0.4.0
    egnn-pytorch==0.2.6
    einops==0.3.0
    En-transformer==0.3.8
    entrypoints==0.3
    equivariant-attention @ file:///workspace/projects/se3-transformer-public
    filelock==3.0.12
    gitdb==4.0.7
    GitPython==3.1.18
    graph-transformer-pytorch==0.0.1
    h5py @ file:///tmp/build/80754af9/h5py_1622088444809/work
    huggingface-hub==0.0.12
    idna==2.10
    importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617877314848/work
    ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl
    ipython @ file:///tmp/build/80754af9/ipython_1617118429768/work
    ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
    ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
    jedi==0.17.0
    Jinja2 @ file:///tmp/build/80754af9/jinja2_1621238361758/work
    joblib==1.0.1
    json5==0.9.6
    jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
    jupyter==1.0.0
    jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
    jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
    jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213308260/work
    jupyter-server==1.8.0
    jupyter-tensorboard==0.2.0
    jupyterlab==3.0.16
    jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
    jupyterlab-server==2.6.0
    jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
    jupytext==1.11.3
    lie-learn @ git+https://github.com/AMLab-Amsterdam/[email protected]
    llvmlite==0.36.0
    local-attention==1.4.1
    markdown-it-py==1.1.0
    MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work
    mdit-py-plugins==0.2.8
    mdtraj==1.9.6
    mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work
    mkl-fft==1.3.0
    mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
    mkl-service==2.3.0
    mp-nerf==0.1.11
    nbclassic==0.3.1
    nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
    nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work
    nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
    nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
    networkx==2.5.1
    notebook @ file:///tmp/build/80754af9/notebook_1621523661196/work
    numba==0.53.1
    numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
    packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
    pandas==1.2.4
    pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work
    parso @ file:///tmp/build/80754af9/parso_1617223946239/work
    pathtools==0.1.2
    performer-pytorch==1.0.11
    pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
    pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
    ProDy==2.0
    prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
    promise==2.3
    prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
    protobuf==3.17.3
    psutil==5.8.0
    ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
    py3Dmol==0.9.1
    pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
    Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work
    pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
    pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work
    python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
    pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
    PyYAML==5.4.1
    pyzmq==20.0.0
    qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work
    QtPy==1.9.0
    regex==2021.4.4
    requests==2.25.1
    sacremoses==0.0.45
    scipy @ file:///tmp/build/80754af9/scipy_1618852618548/work
    se3-transformer-pytorch==0.8.10
    Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
    sentry-sdk==1.1.0
    shortuuid==1.0.1
    sidechainnet==0.6.0
    six @ file:///tmp/build/80754af9/six_1623709665295/work
    smmap==4.0.0
    sniffio==1.2.0
    subprocess32==3.5.4
    terminado==0.9.4
    testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
    tokenizers==0.10.3
    toml==0.10.2
    torch==1.9.0
    tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
    tqdm==4.61.1
    traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
    transformers==4.8.0
    typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
    urllib3==1.26.5
    wandb==0.10.32
    wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
    webencodings==0.5.1
    websocket-client==1.1.0
    widgetsnbextension==3.5.1
    zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
    

    Here is a summary of my system info (lshw -short):

    H/W path    Device  Class      Description
    ==========================================
                        system     Computer
    /0                  bus        Motherboard
    /0/0                memory     59GiB System memory
    /0/1                processor  Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
    /0/100              bridge     440FX - 82441FX PMC [Natoma]
    /0/100/1            bridge     82371SB PIIX3 ISA [Natoma/Triton II]
    /0/100/1.1          storage    82371SB PIIX3 IDE [Natoma/Triton II]
    /0/100/1.3          bridge     82371AB/EB/MB PIIX4 ACPI
    /0/100/2            display    GD 5446
    /0/100/3            network    Elastic Network Adapter (ENA)
    /0/100/1e           display    GK210GL [Tesla K80]
    /0/100/1f           generic    Xen Platform Device
    /1          eth0    network    Ethernet interface
    
    opened by mpdprot 1
  • Whether SE3 needs pre-training

    Thank you for your work. I used your SE3 reimplementation as part of my model, but the current test results are not very good. I suspect this is because I do not understand your model well enough. Here are my questions:

    1. Does your model need pre-training?
    2. Can I train the SE3 Transformer jointly with the fully connected layer that comes after it? Any other advice is also welcome.
    opened by zyk19981118 2
  • multiple molecules cases

    Hi,

    I normally use the dataloader from PyG to handle my molecule dataset.

    Could you please provide an example of a real model with a multi-step training epoch?

    I ran your sample code and it works, but I need to better understand the input format as well as the output format, which for me would be x, y, z.

    [image]

    thanks

    opened by thegodone 2