pytorch_fft




https://github.com/locuslab/pytorch_fft

A PyTorch wrapper for CUDA FFTs



A package that provides a PyTorch C extension for performing batches of 2D cuFFT transformations, by Eric Wong.



Installation

This package is on PyPI. Install with:

pip install pytorch-fft



Usage

  • From the pytorch_fft.fft module, you can use fft2 and ifft2 to perform the forward and inverse FFT transformations. Each complex tensor is passed as a pair of tensors holding its real and imaginary parts.
  • The input tensors are required to have >= 3 dimensions (n1 x ... x nk x row x col), where n1 x ... x nk is the batch of FFT transformations and row x col are the dimensions of each transformation.
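
For example:

import torch
import pytorch_fft.fft as fft

# A batch of 3 complex 4 x 5 matrices, stored as separate real and imaginary parts on the GPU
A_real, A_imag = torch.randn(3, 4, 5).cuda(), torch.zeros(3, 4, 5).cuda()
B_real, B_imag = fft.fft2(A_real, A_imag)  # forward 2D FFT over the last two dimensions
fft.ifft2(B_real, B_imag)  # equals (A_real, A_imag), up to floating-point error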
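To sanity-check results without a GPU, the semantics described above (batched 2D complex-to-complex FFT over the last two dimensions, with real and imaginary parts passed separately) can be mirrored in NumPy. This is a minimal sketch, not part of pytorch_fft: `reference_fft2` and its `inverse` flag are hypothetical, and collapsing the leading dimensions into a single batch is an illustrative assumption rather than a description of the package's C source.

```python
import numpy as np

def reference_fft2(x_real, x_imag, inverse=False):
    """Hypothetical CPU stand-in for fft2/ifft2: batched 2D complex-to-complex
    FFT over the last two axes, taking and returning (real, imag) pairs."""
    x = np.asarray(x_real, dtype=np.float64) + 1j * np.asarray(x_imag, dtype=np.float64)
    if x.ndim < 3:
        raise ValueError("expected >= 3 dims: n1 x ... x nk x row x col")
    *batch_dims, rows, cols = x.shape
    # Assumption (illustrative only): collapse n1 x ... x nk into one batch
    # of independent row x col transforms, then restore the original shape.
    flat = x.reshape(-1, rows, cols)
    transform = np.fft.ifft2 if inverse else np.fft.fft2
    out = transform(flat, axes=(-2, -1)).reshape(*batch_dims, rows, cols)
    return out.real, out.imag

# Round trip on the same shapes as the example above
rng = np.random.default_rng(0)
A_real, A_imag = rng.standard_normal((3, 4, 5)), np.zeros((3, 4, 5))
B_real, B_imag = reference_fft2(A_real, A_imag)
C_real, C_imag = reference_fft2(B_real, B_imag, inverse=True)
print(np.allclose(C_real, A_real), np.allclose(C_imag, A_imag))  # True True
```

Outputs of fft.fft2 could be compared against this reference after moving them to the CPU (for example with `.cpu().numpy()`).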



Notes

  • This follows NumPy semantics, so ifft2(fft2(x)) = x. Note that cuFFT's inverse transform only flips the sign of the exponent and does not divide by the number of elements (row x col), so on its own it is not a true inverse.
  • These functions are NOT PyTorch autograd Functions, and as a result are not backprop-able. What this package allows you to do is call cuFFT on PyTorch tensors.
  • The code currently implements only batched 2D complex-to-complex transformations. If you require a different number of dimensions, the source code can be extended to support it.
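The normalization point in the first note can be checked numerically. The sketch below uses NumPy on the CPU rather than cuFFT; `dft_matrix` and `dft2_unscaled` are hypothetical helpers written for illustration. It computes the 2D DFT directly, once with the forward sign and once with only the sign flipped (mirroring cuFFT's unnormalized inverse), and shows that the sign flip alone returns row x col times the input.

```python
import numpy as np

def dft_matrix(n, sign):
    """n x n DFT matrix with entries exp(sign * 2*pi*i * j*k / n)."""
    j = np.arange(n)
    return np.exp(sign * 2j * np.pi * np.outer(j, j) / n)

def dft2_unscaled(x, sign):
    """2D DFT over the last two axes with no 1/N scaling.
    sign=-1 is the forward transform; sign=+1 only flips the exponent,
    mirroring the unnormalized inverse that cuFFT computes."""
    rows, cols = x.shape[-2:]
    # Left-multiply transforms along the row axis, right-multiply along the
    # column axis (DFT matrices are symmetric, so no transpose is needed).
    return dft_matrix(rows, sign) @ x @ dft_matrix(cols, sign)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4, 5)) + 1j * rng.standard_normal((3, 4, 5))
X = dft2_unscaled(x, -1)            # forward: matches np.fft.fft2(x)
x_signflip = dft2_unscaled(X, +1)   # sign flip only: equals (4 * 5) * x, not x
x_inverse = x_signflip / (4 * 5)    # NumPy-style ifft2 also divides by row * col
print(np.allclose(X, np.fft.fft2(x)),
      np.allclose(x_signflip, 20 * x),
      np.allclose(x_inverse, x))    # True True True
```

The explicit DFT matrices are used only to make the sign and scaling visible; they cost O(n^2) per axis and are not a substitute for an FFT.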



Repository contents

  • pytorch_fft/src: C source code
  • pytorch_fft/fft: Python convenience wrapper
  • build.py: compilation file
  • test.py: tests against NumPy FFTs



Issues and Contributions

If you have any issues or feature requests, file an issue or send in a PR.