Skip to content


[MPS] aten::erfinv metal kernel ops (pytorch#101507)
Browse files Browse the repository at this point in the history
I've added an implementation of erfinv that uses the same algorithm as the CPU implementation in aten/src/ATen/native/Math.h, so that the MPS results match the CPU results in the automated tests. This PR uses the new Metal API calls from pytorch#100661.

Testing shows that MPS gives a substantial speedup (about 270x) over the CPU on a tensor of roughly 200 million elements (`torch.arange(-1, 1, 1e-8)`).
import torch
x = torch.arange(-1, 1, 1e-8) # default cpu tensor
#measure CPU compute time by calling torch.erfinv
time = %timeit -o -q -r 5 torch.erfinv(x)
cpu_time = time.average
print("CPU torch.erfinv time: ", cpu_time)
x ="mps")
# measure MPS compute time
time = %timeit -o -q -r 5 torch.erfinv(x)
mps_time = time.average
print("MPS torch.erfinv time: ", mps_time)
print(f"MPS torch.erfinv is {cpu_time/mps_time*100} percent faster than CPU torch.erfinv")

# compute MSE between MPS and CPU torch.erfinv
x ="cpu")
y_cpu = torch.erfinv(x)
x ="mps")
y_mps = torch.erfinv(x)
y_mps ="cpu")
mask = torch.isfinite(y_cpu) & torch.isfinite("cpu"))
y_mps = y_mps[mask]
y_cpu = y_cpu[mask]
x = x[mask]
print(f"length of y_mps: {len(y_mps)}, length of y_cpu: {len(y_cpu)}, length of x: {len(x)}")
mse = torch.square(y_cpu - y_mps).mean()
print("MSE between MPS and CPU torch.erfinv: ", mse)
diff = torch.abs(y_cpu - y_mps)
print("Largest difference")
print(f"x:  {x[torch.argmax(diff)]}, y_cpu: {y_cpu[torch.argmax(diff)]}, y_mps: {y_mps[torch.argmax(diff)]} , diff = {y_cpu[torch.argmax(diff)] - y_mps[torch.argmax(diff)]}")
CPU torch.erfinv time:  2.654937833400254
MPS torch.erfinv time:  0.009831255332002912
MPS torch.erfinv is 27005.07456822776 percent faster than CPU torch.erfinv
length of y_mps: 199999992, length of y_cpu: 199999992, length of x: 199999992
MSE between MPS and CPU torch.erfinv:  tensor(4.2339e-14)
Largest difference
x:  -0.9999980330467224, y_cpu: -3.363569736480713, y_mps: -3.3635685443878174 , diff = -1.1920928955078125e-06

Fixes pytorch#86808

Pull Request resolved: pytorch#101507
Approved by:
  • Loading branch information
TaiPhamD authored and pytorchmergebot committed Jul 23, 2023
1 parent 12ea12d commit bba06ad
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 1 deletion.
43 changes: 43 additions & 0 deletions aten/src/ATen/native/mps/UnaryConstants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include <metal_stdlib>
using namespace metal;
/* Coefficients of the two rational approximations evaluated by erfinv_mps_kernel.
   a: numerator and b: denominator coefficients of the polynomials in z = y*y that the
      kernel uses when |y| <= 0.7 (the kernel adds the denominator's constant term 1).
   c: numerator and d: denominator coefficients of the polynomials in
      z = sqrt(-log((1 - |y|) / 2)) (again the kernel adds the constant term 1).
   NOTE(review): the doubled braces look like fmt-style escaping inside a kernel source
   template - confirm against the code that formats this text. */
constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}};
constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}};
constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
constant float d[2] = {{3.543889200, 1.637067800}};
/*
 * Element-wise inverse error function: each thread computes
 * output[index] = erfinv(input[index]) in single precision.
 *
 * The element types of the two buffers are template placeholders and braces are doubled.
 * NOTE(review): this presumably is the text formatted by the host code before it is
 * compiled - keep single braces out of any comment added here.
 *
 * Results by input range:
 *   |y| > 1    -> NaN
 *   |y| == 1   -> infinity carrying the sign of y
 *   |y| <= 0.7 -> rational approximation in z = y*y (tables a, b)
 *   otherwise  -> rational approximation in z = sqrt(-log((1 - |y|) / 2)) (tables c, d)
 */
kernel void erfinv_mps_kernel(device {0} *output [[buffer(0)]],
                              device {1} *input [[buffer(1)]],
                              uint index [[thread_position_in_grid]]) {{
  float y = input[index];
  float x, z, num, dem; /* working variables */
  float y_abs = abs(y);

  /* Outside the domain of erfinv. */
  if (y_abs > 1.0f) {{
    output[index] = NAN;
    return;
  }}
  /* Endpoints of the domain map to signed infinity. */
  if (y_abs == 1.0f) {{
    output[index] = copysign(INFINITY, y);
    return;
  }}

  if (y_abs <= 0.7f) {{
    /* Central range. */
    z = y * y;
    num = (((a[3] * z + a[2]) * z + a[1]) * z + a[0]);
    dem = ((((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0f);
    x = y * num / dem;
  }} else {{
    /* Tails; a NaN input also lands here and propagates through log. */
    z = sqrt(-1.0f * log((1.0f - y_abs) / 2.0f));
    num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
    dem = (d[1] * z + d[0]) * z + 1.0f;
    x = copysign(num, y) / dem;
  }}

  output[index] = x;
}}
133 changes: 133 additions & 0 deletions aten/src/ATen/native/mps/operations/
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/UnaryConstants.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ops/erfinv_native.h>

#include <fmt/format.h>

namespace at::native {
/// Returns the Metal type name that corresponds to the scalar type `t`.
///
/// The returned reference points into a function-local static table, so it stays valid
/// for the lifetime of the process.
/// Fails a TORCH_CHECK ("Unsupported type ...") when `t` has no entry in the table.
const std::string& getMetalType(const c10::ScalarType& t) {
  // Mapping from c10::ScalarType to the Metal type name that can be used for unary ops
  static const std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = {
      {c10::ScalarType::Half, "half"},
      {c10::ScalarType::Float, "float"},
      {c10::ScalarType::Long, "long"},
      {c10::ScalarType::Int, "int"},
      {c10::ScalarType::Short, "short"},
      {c10::ScalarType::Bool, "bool"},
      {c10::ScalarType::Char, "int8_t"},
      {c10::ScalarType::Byte, "uint8_t"},
  };

  const auto it = scalar_to_metal_type.find(t);
  TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t);
  return it->second;
}

/// Convenience overload: Metal type name for the scalar type reported by `s.type()`.
const std::string& getMetalType(const c10::Scalar& s) {
  return getMetalType(s.type());
}

/// Convenience overload: Metal type name for the element type of tensor `t`.
const std::string& getMetalType(const Tensor& t) {
  return getMetalType(t.scalar_type());
}

/// Returns the Metal library built from UNARY_KERNEL_TEMPLATE for the type pair (t1, t2).
///
/// The template is formatted with `t1` and `t2` as its two arguments and compiled once;
/// the resulting library is cached in a function-local static map and reused afterwards.
/// Fails a TORCH_CHECK when the compiler does not produce a library.
static id<MTLLibrary> compileUnaryOpsLibrary(id<MTLDevice> device, const std::string& t1, const std::string& t2) {
  static std::unordered_map<std::string, id<MTLLibrary>> libMap;
  // The separator keeps keys unambiguous when type names are concatenated.
  const auto key = t1 + ":" + t2;
  auto it = libMap.find(key);
  if (it != libMap.end()) {
    return it->second;
  }

  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  const std::string source = fmt::format(UNARY_KERNEL_TEMPLATE, t1, t2);
  id<MTLLibrary> rc = [device newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()]
                                           options:options
                                             error:&error];
  // Judge success by the returned library: `error` may also be populated when the
  // compiler only emitted warnings.
  TORCH_CHECK(rc != nil,
              "Failed to compile library: ",
              error ? [[error localizedDescription] UTF8String] : "unknown error");
  libMap[key] = rc;
  return rc;
}

/// Returns the compute pipeline state for kernel `fname` specialized for types (t1, t2).
///
/// Pipeline states are cached in a function-local static map keyed by the two type names
/// and the function name; on a miss the library is obtained from compileUnaryOpsLibrary,
/// the function is looked up by name and a pipeline state is created from it.
/// Fails a TORCH_CHECK when the function is missing or the pipeline cannot be created.
static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
                                               const std::string& t1,
                                               const std::string& t2,
                                               const std::string& fname) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
  // The separators keep keys unambiguous when the three parts are concatenated.
  const auto key = t1 + ":" + t2 + ":" + fname;
  auto it = cplMap.find(key);
  if (it != cplMap.end()) {
    return it->second;
  }

  NSError* error = nil;
  auto library = compileUnaryOpsLibrary(device, t1, t2);
  // NOTE(review): `func` comes from a `new...` method and is never released here;
  // confirm whether that is intended under this file's memory-management mode.
  id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
  TORCH_CHECK(func != nil, "Can't get function ", fname);
  auto rc = [device newComputePipelineStateWithFunction:func error:&error];
  // Judge success by the returned object rather than by the error out-parameter.
  TORCH_CHECK(rc != nil,
              "Failed to construct pipeline state: ",
              error ? [[error localizedDescription] UTF8String] : "unknown error");
  cplMap[key] = rc;
  return rc;
}

// Computes erfinv(self) element-wise into output_ by dispatching the
// "erfinv_mps_kernel" Metal kernel with one thread per output element.
//
// - Double inputs are rejected with a TORCH_CHECK.
// - An empty output returns immediately.
// - The kernel indexes both buffers linearly, so non-contiguous tensors are staged
//   through contiguous temporaries and the result is copied back into output_.
TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
  // handle erfinv ops using metal kernel
  // erfinv algorithm ported from aten/src/ATen/native/Math.h

  TORCH_CHECK(self.scalar_type() != ScalarType::Double, "MPS does not support erfinv op with scalar type: Double");

  Tensor outputTensor = output_;
  bool needs_output_copy = false;
  uint32_t length = output_.numel();
  if (length == 0) {
    return;
  }

  using namespace mps;
  @autoreleasepool {
    // Each tensor is checked on its own: a contiguous input paired with a
    // non-contiguous output must still be written through a contiguous temporary.
    Tensor inputTensor = self.is_contiguous() ? self : self.contiguous();
    if (!output_.is_contiguous()) {
      outputTensor = output_.contiguous();
      needs_output_copy = true;
    }

    id<MTLDevice> device = MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> cplState =
        getCPLState(device, getMetalType(outputTensor), getMetalType(self), "erfinv_mps_kernel");

    MPSStream* mpsStream = getCurrentMPSStream();
    dispatch_sync(mpsStream->queue(), ^() {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      id<MTLBuffer> outBuf = getMTLBufferStorage(outputTensor);
      id<MTLBuffer> inputBuf = getMTLBufferStorage(inputTensor);

      getMPSProfiler().beginProfileKernel(cplState, "erf_inv", {self});

      [computeEncoder setComputePipelineState:cplState];
      // Bind at the tensor's byte offset so that views starting past the beginning of
      // their storage are addressed correctly.
      // NOTE(review): assumes getMTLBufferStorage returns the buffer that backs the
      // whole storage - confirm.
      [computeEncoder setBuffer:outBuf offset:outputTensor.storage_offset() * outputTensor.element_size() atIndex:0];
      [computeEncoder setBuffer:inputBuf offset:inputTensor.storage_offset() * inputTensor.element_size() atIndex:1];

      // One thread per element; the threadgroup is capped by the pipeline's limit.
      MTLSize gridSize = MTLSizeMake(length, 1, 1);
      uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
      NSUInteger threadsPerGroupSize = std::min(maxThreadsPerGroup, length);
      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);
      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(cplState);
    });
  }
  if (needs_output_copy) {
    output_.copy_(outputTensor);
  }
}
} // namespace at::native
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9124,6 +9124,7 @@
structured_inherits: TensorIteratorBase
CPU, CUDA: erfinv_out
MPS: erfinv_out_mps
SparseCPU, SparseCUDA: erfinv_sparse_out
SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr_out
tags: pointwise
Expand Down
3 changes: 2 additions & 1 deletion test/
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,6 @@ def mps_ops_modifier(ops):
'cumprod': None,
'digamma': None,
'erfc': None,
'erfinv': None,
'frexp': None,
'gcd': None,
'geqrf': None,
Expand Down Expand Up @@ -7548,6 +7547,8 @@ def helper(shape, op):
helper((2, 8, 3, 5), torch.expm1)
helper((2, 8, 3, 5), torch.log)
helper((2, 8, 3, 5), torch.cos)
helper((2, 8, 3, 5), torch.erfinv)

def test_non_dense_in_storage_unary_ops(self):
def helper(op):
Expand Down

0 comments on commit bba06ad

Please sign in to comment.