Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to devices without CUDA #50

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions modules/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
from torch.utils.cpp_extension import load

_src_path = path.join(path.dirname(path.abspath(__file__)), "src")

if torch.cuda.is_available():
source_files = [ "inplace_abn.cpp", "inplace_abn_cpu.cpp",
"inplace_abn_cuda.cu", "inplace_abn_cuda_half.cu"]
loaf_kwargs = { "extra_cuda_cflags": ["--expt-extended-lambda"] }
else:
source_files = [ "inplace_abn.cpp", "inplace_abn_cpu.cpp" ]
load_kwargs = { "with_cuda": False }


_backend = load(name="inplace_abn",
extra_cflags=["-O3"],
sources=[path.join(_src_path, f) for f in [
"inplace_abn.cpp",
"inplace_abn_cpu.cpp",
"inplace_abn_cuda.cu",
"inplace_abn_cuda_half.cu"
]],
extra_cuda_cflags=["--expt-extended-lambda"])
sources=[path.join(_src_path, f) for f in source_files],
**load_kwargs)

# Activation names
ACT_RELU = "relu"
Expand Down
53 changes: 36 additions & 17 deletions modules/src/inplace_abn.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,19 @@
#include <vector>

std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);

at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
bool affine, float eps);
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
bool affine, float eps);
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
bool affine, float eps);

std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
bool affine, float eps);
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
bool affine, float eps);
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
bool affine, float eps);

at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
at::Tensor edz, at::Tensor eydz, bool affine, float eps);

void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);

void elu_backward_cpu(at::Tensor z, at::Tensor dz);
void elu_backward_cuda(at::Tensor z, at::Tensor dz);

static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
num = x.size(0);
Expand All @@ -51,6 +34,30 @@ static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {

#include "utils/cuda.cuh"

std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);

at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
bool affine, float eps);
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
bool affine, float eps);

std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
bool affine, float eps);
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
bool affine, float eps);

at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
at::Tensor edz, at::Tensor eydz, bool affine, float eps);

void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);

void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);

void elu_backward_cuda(at::Tensor z, at::Tensor dz);

template <typename T, typename Op>
__device__ T reduce(Op op, int plane, int N, int S) {
T sum = (T)0;
Expand Down Expand Up @@ -85,4 +92,16 @@ __device__ T reduce(Op op, int plane, int N, int S) {
// Everyone picks it up, should be broadcast into the whole gradInput
return shared[0];
}
#else
const auto mean_var_cuda = mean_var_cpu;
const auto mean_var_cuda_h = mean_var_cpu;
const auto forward_cuda = forward_cpu;
const auto forward_cuda_h = forward_cpu;
const auto edz_eydz_cuda = edz_eydz_cpu;
const auto edz_eydz_cuda_h = edz_eydz_cpu;
const auto backward_cuda = backward_cpu;
const auto backward_cuda_h = backward_cpu;
const auto leaky_relu_backward_cuda = leaky_relu_backward_cpu;
const auto leaky_relu_backward_cuda_h = leaky_relu_backward_cpu;
const auto elu_backward_cuda = elu_backward_cpu;
#endif
4 changes: 2 additions & 2 deletions networks/AugmentCE2P.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.nn import functional as F
# Note here we adopt the InplaceABNSync implementation from https://github.com/mapillary/inplace_abn
# By default, the InplaceABNSync module contains a BatchNorm Layer and a LeakyReLu layer
from modules import InPlaceABNSync
from ..modules import InPlaceABNSync

BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

Expand Down Expand Up @@ -118,7 +118,7 @@ def forward(self, feats):

class ASPPModule(nn.Module):
"""
Reference:
Reference:
Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
"""

Expand Down
2 changes: 1 addition & 1 deletion networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import absolute_import

from networks.AugmentCE2P import resnet101
from .AugmentCE2P import resnet101

__factory = {
'resnet101': resnet101,
Expand Down