[MPS] aten::erfinv metal kernel ops (pytorch#101507)
I've added an implementation of erfinv using the algorithm from https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152, so that the MPS results match the CPU implementation checked by the automated tests. This PR uses the new Metal API calls from pytorch#100661.

Testing shows MPS achieves a substantial speedup (roughly 270x) over CPU on a tensor of 200 million elements:
```
import torch  # run under IPython/Jupyter: the timing below uses the %timeit magic
x = torch.arange(-1, 1, 1e-8)  # CPU tensor (default device), 200 million elements
# measure CPU compute time by calling torch.erfinv
time = %timeit -o -q -r 5 torch.erfinv(x)
cpu_time = time.average
print("CPU torch.erfinv time: ", cpu_time)
x = x.to("mps")
# measure MPS compute time
time = %timeit -o -q -r 5 torch.erfinv(x)
mps_time = time.average
print("MPS torch.erfinv time: ", mps_time)
print(f"MPS torch.erfinv is {cpu_time/mps_time*100} percent faster than CPU torch.erfinv")

# compute MSE between MPS and CPU torch.erfinv
x = x.to("cpu")
y_cpu = torch.erfinv(x)
x = x.to("mps")
y_mps = torch.erfinv(x)
y_mps = y_mps.to("cpu")
mask = torch.isfinite(y_cpu) & torch.isfinite(y_mps)  # y_mps is already on the CPU here
y_mps = y_mps[mask]
y_cpu = y_cpu[mask]
x = x[mask]
print(f"length of y_mps: {len(y_mps)}, length of y_cpu: {len(y_cpu)}, length of x: {len(x)}")
mse = torch.square(y_cpu - y_mps).mean()
print("MSE between MPS and CPU torch.erfinv: ", mse)
diff = torch.abs(y_cpu - y_mps)
print("Largest difference")
print(f"x:  {x[torch.argmax(diff)]}, y_cpu: {y_cpu[torch.argmax(diff)]}, y_mps: {y_mps[torch.argmax(diff)]} , diff = {y_cpu[torch.argmax(diff)] - y_mps[torch.argmax(diff)]}")
```
Output:
CPU torch.erfinv time:  2.654937833400254
MPS torch.erfinv time:  0.009831255332002912
MPS torch.erfinv is 27005.07456822776 percent faster than CPU torch.erfinv
length of y_mps: 199999992, length of y_cpu: 199999992, length of x: 199999992
MSE between MPS and CPU torch.erfinv:  tensor(4.2339e-14)
Largest difference
x:  -0.9999980330467224, y_cpu: -3.363569736480713, y_mps: -3.3635685443878174 , diff = -1.1920928955078125e-06
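
As a quicker spot check than the full benchmark, a minimal parity test could look like the sketch below (assuming a build where MPS is available; the tolerance reflects the ~1.2e-6 maximum difference measured above):
```
import torch

# Compare CPU and MPS erfinv on the open interval (-1, 1).
x = torch.linspace(-0.999, 0.999, 10_000)
y_cpu = torch.erfinv(x)
y_mps = torch.erfinv(x.to("mps")).cpu()
assert torch.allclose(y_cpu, y_mps, atol=1e-5)
```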

Fixes pytorch#86808

Pull Request resolved: pytorch#101507
Approved by: https://github.com/kulinseth
TaiPhamD authored and pytorchmergebot committed Jul 23, 2023
1 parent 12ea12d commit bba06ad
Showing 4 changed files with 179 additions and 1 deletion.
43 changes: 43 additions & 0 deletions aten/src/ATen/native/mps/UnaryConstants.h
#pragma once

const char* UNARY_KERNEL_TEMPLATE = R"METAL(
#include <metal_stdlib>
using namespace metal;

constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}};
constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}};
constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
constant float d[2] = {{3.543889200, 1.637067800}};

kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],
                               device {1} *input [[buffer(1)]],
                               uint index [[thread_position_in_grid]]) {{
  float y = input[index];
  float x, z, num, dem; /* working variables */
  /* coefficients in rational expansion */
  float y_abs = abs(y);
  if (y_abs > 1.0f) {{
    output[index] = NAN;
    return;
  }}
  if (y_abs == 1.0f) {{
    output[index] = copysign(INFINITY, y);
    return;
  }}
  if (y_abs <= 0.7f) {{
    z = y * y;
    num = (((a[3] * z + a[2]) * z + a[1]) * z + a[0]);
    dem = ((((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0f);
    x = y * num / dem;
  }} else {{
    z = sqrt(-1.0f * log((1.0 - y_abs) / 2.0));
    num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
    dem = (d[1] * z + d[0]) * z + 1.0f;
    x = copysign(num, y) / dem;
  }}
  output[index] = x;
}})METAL";
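
For readers who don't want to mentally unescape the `{{`/`}}` in the template, here is the same rational approximation transcribed to plain Python — a reference sketch of the kernel's math, not code from this PR:
```
import math

# Coefficients of the rational expansions, copied from the kernel above.
A = [0.886226899, -1.645349621, 0.914624893, -0.140543331]
B = [-2.118377725, 1.442710462, -0.329097515, 0.012229801]
C = [-1.970840454, -1.624906493, 3.429567803, 1.641345311]
D = [3.543889200, 1.637067800]

def erfinv_ref(y: float) -> float:
    y_abs = abs(y)
    if y_abs > 1.0:
        return math.nan
    if y_abs == 1.0:
        return math.copysign(math.inf, y)
    if y_abs <= 0.7:  # central region: rational approximation in y*y
        z = y * y
        num = ((A[3] * z + A[2]) * z + A[1]) * z + A[0]
        dem = (((B[3] * z + B[2]) * z + B[1]) * z + B[0]) * z + 1.0
        return y * num / dem
    # tail region: rational approximation in sqrt(-log((1 - |y|) / 2))
    z = math.sqrt(-math.log((1.0 - y_abs) / 2.0))
    num = ((C[3] * z + C[2]) * z + C[1]) * z + C[0]
    dem = (D[1] * z + D[0]) * z + 1.0
    return math.copysign(num, y) / dem
```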
133 changes: 133 additions & 0 deletions aten/src/ATen/native/mps/operations/UnaryKernel.mm
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/UnaryConstants.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/erfinv_native.h>
#endif

#include <fmt/format.h>

namespace at::native {

const std::string& getMetalType(const c10::ScalarType& t) {
  // Mapping from c10::ScalarType to the Metal type name used in the kernel template
  static std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = {
      {c10::ScalarType::Half, "half"},
      {c10::ScalarType::Float, "float"},
      {c10::ScalarType::Long, "long"},
      {c10::ScalarType::Int, "int"},
      {c10::ScalarType::Short, "short"},
      {c10::ScalarType::Bool, "bool"},
      {c10::ScalarType::Char, "int8_t"},
      {c10::ScalarType::Byte, "uint8_t"},
  };

  auto it = scalar_to_metal_type.find(t);
  TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t);
  return it->second;
}

const std::string& getMetalType(const c10::Scalar& s) {
  return getMetalType(s.type());
}

const std::string& getMetalType(const Tensor& t) {
  return getMetalType(t.scalar_type());
}

// Compiles (and caches) the Metal library for a given output/input type pair.
static id<MTLLibrary> compileUnaryOpsLibrary(id<MTLDevice> device, const std::string& t1, const std::string& t2) {
  auto key = t1 + t2;
  static std::unordered_map<std::string, id<MTLLibrary>> libMap;
  auto it = libMap.find(key);
  if (it != libMap.end()) {
    return it->second;
  }
  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  auto rc =
      [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(UNARY_KERNEL_TEMPLATE, t1, t2).c_str()]
                           options:options
                             error:&error];
  TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
  libMap[key] = rc;
  return rc;
}

// Looks up (and caches) the compute pipeline state for a kernel function.
static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
                                               const std::string& t1,
                                               const std::string& t2,
                                               const std::string& fname) {
  auto key = t1 + t2 + fname;
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
  auto it = cplMap.find(key);
  if (it != cplMap.end()) {
    return it->second;
  }
  NSError* error = nil;
  auto library = compileUnaryOpsLibrary(device, t1, t2);
  id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
  TORCH_CHECK(func != nil, "Can't get function ", fname);
  auto rc = [device newComputePipelineStateWithFunction:func error:&error];
  TORCH_CHECK(
      rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
  cplMap[key] = rc;
  return rc;
}

TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
  // handle erfinv ops using a Metal kernel
  // erfinv algorithm ported from aten/src/ATen/native/Math.h
  // https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152

  TORCH_CHECK(self.scalar_type() != ScalarType::Double, "MPS does not support erfinv op with scalar type: Double");

  Tensor outputTensor = output_;
  bool needs_output_copy = false;
  uint32_t length = output_.numel();
  if (length == 0) {
    return;
  }
  using namespace mps;
  @autoreleasepool {
    Tensor inputTensor = self;
    id<MTLDevice> device = MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> cplState =
        getCPLState(device, getMetalType(outputTensor), getMetalType(self), "erfinv_mps_kernel");

    if (!self.is_contiguous()) {
      inputTensor = inputTensor.contiguous();
      outputTensor = outputTensor.contiguous();
      needs_output_copy = true;
    }

    MPSStream* mpsStream = getCurrentMPSStream();
    dispatch_sync(mpsStream->queue(), ^() {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      id<MTLBuffer> outBuf = getMTLBufferStorage(outputTensor);
      id<MTLBuffer> inputBuf = getMTLBufferStorage(inputTensor);

      getMPSProfiler().beginProfileKernel(cplState, "erf_inv", {self});

      [computeEncoder setComputePipelineState:cplState];
      [computeEncoder setBuffer:outBuf offset:0 atIndex:0];
      [computeEncoder setBuffer:inputBuf offset:0 atIndex:1];

      MTLSize gridSize = MTLSizeMake(length, 1, 1);
      uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
      NSUInteger threadsPerGroupSize = std::min(maxThreadsPerGroup, length);
      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);
      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(cplState);
    });
  }
  if (needs_output_copy) {
    output_.copy_(outputTensor);
  }
}

} // namespace at::native
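
One detail worth noting in `UnaryKernel.mm`: the kernel source is produced by `fmt::format`, whose `{0}`/`{1}` placeholders and `{{`/`}}` escapes are compatible with Python's `str.format`. A small sketch of how the template expands (type names chosen for illustration):
```
# Illustration only: fmt::format(UNARY_KERNEL_TEMPLATE, "float", "half")
# substitutes {0}/{1} and collapses {{ / }} to literal braces, like str.format.
template = (
    "kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],\n"
    "                               device {1} *input [[buffer(1)]],\n"
    "                               uint index [[thread_position_in_grid]]) {{"
)
print(template.format("float", "half"))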
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -9124,6 +9124,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: erfinv_out
+    MPS: erfinv_out_mps
     SparseCPU, SparseCUDA: erfinv_sparse_out
     SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr_out
   tags: pointwise
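
With the dispatch entry in place, `torch.erfinv` on an MPS tensor routes to the new kernel; a minimal usage sketch:
```
import torch

if torch.backends.mps.is_available():
    t = torch.tensor([-0.5, 0.0, 0.5], device="mps")
    print(torch.erfinv(t))  # routes to erfinv_out_mps
    print(t.erfinv_())      # the in-place variant goes through the same out kernel
```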
3 changes: 2 additions & 1 deletion test/test_mps.py
@@ -407,7 +407,6 @@ def mps_ops_modifier(ops):
         'cumprod': None,
         'digamma': None,
         'erfc': None,
-        'erfinv': None,
         'frexp': None,
         'gcd': None,
         'geqrf': None,
@@ -7548,6 +7547,8 @@ def helper(shape, op):
         helper((2, 8, 3, 5), torch.expm1)
         helper((2, 8, 3, 5), torch.log)
         helper((2, 8, 3, 5), torch.cos)
+        helper((2, 8, 3, 5), torch.erfinv)
+

     def test_non_dense_in_storage_unary_ops(self):
         def helper(op):
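
For context, the `helper` used in that test compares each op's MPS result against CPU; sketched here from its usage, not the exact source:
```
import torch

def helper(shape, op):
    # Run the same op on identical CPU and MPS inputs and require matching results.
    cpu_x = torch.randn(shape, device="cpu", dtype=torch.float32)
    mps_x = cpu_x.detach().clone().to("mps")
    torch.testing.assert_close(op(mps_x).cpu(), op(cpu_x))

helper((2, 8, 3, 5), torch.erfinv)
```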
