forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MPS] aten::erfinv metal kernel ops (pytorch#101507)
I've added the implementation of erfinv using the algorithm from https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152 in order for the MPS based algorithm to match the CPU automatic test. This PR is using the new metal api calls from pytorch#100661 Testing shows MPS has a decent speed up (270x) compared to CPU on tensor size of 100 mil elements. ``` import torch x = torch.arange(-1, 1, 1e-8) # default cpu tensor #measure CPU compute time by calling torch.erfinv time = %timeit -o -q -r 5 torch.erfinv(x) cpu_time = time.average print("CPU torch.erfinv time: ", cpu_time) x = x.to("mps") # measure MPS compute time time = %timeit -o -q -r 5 torch.erfinv(x) mps_time = time.average print("MPS torch.erfinv time: ", mps_time) print(f"MPS torch.erfinv is {cpu_time/mps_time*100} percent faster than CPU torch.erfinv") # compute MSE between MPS and CPU torch.erfinv x = x.to("cpu") y_cpu = torch.erfinv(x) x = x.to("mps") y_mps = torch.erfinv(x) y_mps = y_mps.to("cpu") mask = torch.isfinite(y_cpu) & torch.isfinite(y_mps.to("cpu")) y_mps = y_mps[mask] y_cpu = y_cpu[mask] x = x[mask] print(f"length of y_mps: {len(y_mps)}, length of y_cpu: {len(y_cpu)}, length of x: {len(x)}") mse = torch.square(y_cpu - y_mps).mean() print("MSE between MPS and CPU torch.erfinv: ", mse) diff = torch.abs(y_cpu - y_mps) print("Largest difference") print(f"x: {x[torch.argmax(diff)]}, y_cpu: {y_cpu[torch.argmax(diff)]}, y_mps: {y_mps[torch.argmax(diff)]} , diff = {y_cpu[torch.argmax(diff)] - y_mps[torch.argmax(diff)]}") ``` CPU torch.erfinv time: 2.654937833400254 MPS torch.erfinv time: 0.009831255332002912 MPS torch.erfinv is 27005.07456822776 percent faster than CPU torch.erfinv length of y_mps: 199999992, length of y_cpu: 199999992, length of x: 199999992 MSE between MPS and CPU torch.erfinv: tensor(4.2339e-14) Largest difference x: -0.9999980330467224, y_cpu: -3.363569736480713, y_mps: -3.3635685443878174 , diff = -1.1920928955078125e-06 
Fixes pytorch#86808. Pull Request resolved: pytorch#101507. Approved by: https://github.com/kulinseth
- Loading branch information
1 parent
12ea12d
commit bba06ad
Showing
4 changed files
with
179 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#pragma once | ||
|
||
// Metal source template for element-wise unary kernels.
//
// The text is a format string: `{0}` is replaced with the element type of the
// output buffer and `{1}` with the element type of the input buffer. Literal
// braces that belong to the Metal code are therefore written doubled (`{{`,
// `}}`).
//
// `erfinv_mps_kernel` evaluates the inverse error function for one element per
// thread:
//   * |y| > 1    -> NAN
//   * |y| == 1   -> INFINITY carrying the sign of y
//   * |y| <= 0.7 -> rational approximation in z = y * y (coefficients a, b)
//   * otherwise  -> rational approximation in z = sqrt(-log((1 - |y|) / 2))
//                   (coefficients c, d), with the sign of y applied
//
// The pointer is declared `static ... const` so that every translation unit
// including this header gets its own internal-linkage copy; a plain
// `const char*` variable defined in a header has external linkage and produces
// duplicate-symbol errors as soon as a second translation unit includes it.
static const char* const UNARY_KERNEL_TEMPLATE = R"METAL(
#include <metal_stdlib>
using namespace metal;
constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}};
constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}};
constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
constant float d[2] = {{3.543889200, 1.637067800}};
kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],
device {1} *input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {{
float y = input[index];
float x, z, num, dem; /*working variables */
/* coefficients in rational expansion */
float y_abs = abs(y);
if(y_abs > 1.0f){{
output[index] = NAN;
return;
}}
if(y_abs == 1.0f){{
output[index] = copysign(INFINITY, y);
return;
}}
if(y_abs <= 0.7f) {{
z = y * y;
num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f);
x = y * num / dem;
}}
else{{
z = sqrt(-1.0f*log((1.0-y_abs)/2.0));
num = ((c[3]*z + c[2])*z + c[1]) * z + c[0];
dem = (d[1]*z + d[0])*z + 1.0f;
x = copysign(num, y) / dem;
}}
output[index] = x;
}})METAL";
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||
#include <ATen/mps/MPSProfiler.h> | ||
#include <ATen/native/UnaryOps.h> | ||
#include <ATen/native/mps/OperationUtils.h> | ||
#include <ATen/native/mps/UnaryConstants.h> | ||
#ifndef AT_PER_OPERATOR_HEADERS | ||
#include <ATen/Functions.h> | ||
#include <ATen/NativeFunctions.h> | ||
#else | ||
#include <ATen/ops/erfinv_native.h> | ||
#endif | ||
|
||
#include <fmt/format.h> | ||
|
||
namespace at::native { | ||
// Looks up the type name that is substituted into the Metal kernel template
// for the given scalar type. The returned reference points into a
// function-local static table, so it stays valid for the life of the program.
// Fails through TORCH_CHECK when the scalar type has no entry in the table.
const std::string& getMetalType(const c10::ScalarType& t) {
  static const std::unordered_map<c10::ScalarType, std::string> metal_type_names{
      {c10::ScalarType::Half, "half"},
      {c10::ScalarType::Float, "float"},
      {c10::ScalarType::Long, "long"},
      {c10::ScalarType::Int, "int"},
      {c10::ScalarType::Short, "short"},
      {c10::ScalarType::Bool, "bool"},
      {c10::ScalarType::Char, "int8_t"},
      {c10::ScalarType::Byte, "uint8_t"},
  };

  const auto entry = metal_type_names.find(t);
  TORCH_CHECK(entry != metal_type_names.end(), "Unsupported type ", t);
  return entry->second;
}
|
||
// Convenience overload: resolves the type name from the scalar type reported
// by a c10::Scalar.
const std::string& getMetalType(const c10::Scalar& s) {
  const c10::ScalarType scalar_dtype = s.type();
  return getMetalType(scalar_dtype);
}
|
||
// Convenience overload: resolves the type name from a tensor's scalar type.
const std::string& getMetalType(const Tensor& t) {
  const c10::ScalarType tensor_dtype = t.scalar_type();
  return getMetalType(tensor_dtype);
}
|
||
// Compiles UNARY_KERNEL_TEMPLATE for one (output type, input type) pair and
// returns the resulting Metal library.
//
// t1 is substituted for `{0}` (output element type) and t2 for `{1}` (input
// element type) in the template. Compiled libraries are memoised in a
// function-local static map keyed by `t1 + t2`, so each type pair is compiled
// at most once.
//
// Fails through TORCH_CHECK when Metal does not return a library.
static id<MTLLibrary> compileUnaryOpsLibrary(id<MTLDevice> device, const std::string& t1, const std::string& t2) {
  static std::unordered_map<std::string, id<MTLLibrary>> libMap;
  const auto key = t1 + t2;
  const auto cached = libMap.find(key);
  if (cached != libMap.end()) {
    return cached->second;
  }

  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  const std::string source = fmt::format(UNARY_KERNEL_TEMPLATE, t1, t2);
  id<MTLLibrary> library = [device newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()]
                                                options:options
                                                  error:&error];

  // Success is decided by the returned library alone: the error out-parameter
  // may also be populated when the compiler only emitted warnings, in which
  // case the library is still usable.
  // -UTF8String on a nil error yields NULL, which must not be streamed into
  // the message, hence the explicit fallback text.
  const char* reason = error != nil ? [[error localizedDescription] UTF8String] : nullptr;
  TORCH_CHECK(library != nil, "Failed to compile library: ", reason != nullptr ? reason : "unknown error");

  libMap[key] = library;
  return library;
}
|
||
// Returns the compute pipeline state for kernel `fname` compiled for the
// (output type t1, input type t2) pair.
//
// Pipeline states are memoised in a function-local static map keyed by
// `t1 + t2 + fname`; on a miss the library is obtained from
// compileUnaryOpsLibrary, the function is looked up by name and a pipeline
// state is built from it.
//
// Fails through TORCH_CHECK when the function cannot be found or the pipeline
// state cannot be constructed.
static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
                                               const std::string& t1,
                                               const std::string& t2,
                                               const std::string& fname) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
  const auto key = t1 + t2 + fname;
  const auto cached = cplMap.find(key);
  if (cached != cplMap.end()) {
    return cached->second;
  }

  id<MTLLibrary> library = compileUnaryOpsLibrary(device, t1, t2);
  id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
  TORCH_CHECK(func != nil, "Can't get function ", fname);

  NSError* error = nil;
  id<MTLComputePipelineState> state = [device newComputePipelineStateWithFunction:func error:&error];
  // `newFunctionWithName:` hands back an owned (+1) object and this file uses
  // manual retain/release, so drop our reference once the pipeline state has
  // been requested; it is not needed afterwards.
  [func release];

  // Decide success from the return value, and never stream a NULL C string
  // into the message when no error object was provided.
  const char* reason = error != nil ? [[error localizedDescription] UTF8String] : nullptr;
  TORCH_CHECK(state != nil, "Failed to construct pipeline state: ", reason != nullptr ? reason : "unknown error");

  cplMap[key] = state;
  return state;
}
|
||
TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
  // Computes erfinv(self) element-wise into output_ using the custom Metal
  // kernel "erfinv_mps_kernel".
  // erfinv algorithm ported from aten/src/ATen/native/Math.h
  // https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152
  //
  // Raises (TORCH_CHECK) for Double inputs and for tensors whose element count
  // does not fit the kernel's 32-bit thread index.

  TORCH_CHECK(self.scalar_type() != ScalarType::Double, "MPS does not support erfinv op with scalar type: Double");

  const int64_t numel = output_.numel();
  if (numel == 0) {
    return;
  }
  // The kernel indexes elements with a 32-bit `uint`; refuse larger tensors
  // instead of silently truncating the element count and leaving part of the
  // output unwritten.
  TORCH_CHECK(numel <= static_cast<int64_t>(UINT32_MAX),
              "MPS erfinv supports at most ",
              UINT32_MAX,
              " elements, but got ",
              numel);
  const uint32_t length = static_cast<uint32_t>(numel);

  using namespace mps;

  // The kernel reads and writes both buffers linearly, so each tensor must be
  // contiguous on its own account. Input and output are checked independently:
  // a strided `out=` tensor needs a contiguous staging tensor even when `self`
  // is already contiguous.
  const bool needs_output_copy = !output_.is_contiguous();
  const Tensor inputTensor = self.contiguous();
  const Tensor outputTensor = output_.contiguous();

  // Contiguous views (e.g. slices) may start part-way into their storage; the
  // buffers have to be bound at that byte offset rather than at 0.
  const NSUInteger inputByteOffset = static_cast<NSUInteger>(inputTensor.storage_offset() * inputTensor.element_size());
  const NSUInteger outputByteOffset =
      static_cast<NSUInteger>(outputTensor.storage_offset() * outputTensor.element_size());

  @autoreleasepool {
    id<MTLDevice> device = MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> cplState =
        getCPLState(device, getMetalType(outputTensor), getMetalType(self), "erfinv_mps_kernel");

    MPSStream* mpsStream = getCurrentMPSStream();
    dispatch_sync(mpsStream->queue(), ^() {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      id<MTLBuffer> outBuf = getMTLBufferStorage(outputTensor);
      id<MTLBuffer> inputBuf = getMTLBufferStorage(inputTensor);

      getMPSProfiler().beginProfileKernel(cplState, "erf_inv", {self});

      [computeEncoder setComputePipelineState:cplState];
      [computeEncoder setBuffer:outBuf offset:outputByteOffset atIndex:0];
      [computeEncoder setBuffer:inputBuf offset:inputByteOffset atIndex:1];

      // One thread per element; the threadgroup is capped by what the
      // pipeline state allows.
      MTLSize gridSize = MTLSizeMake(length, 1, 1);
      const NSUInteger maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
      const NSUInteger threadsPerGroupSize = std::min<NSUInteger>(maxThreadsPerGroup, length);
      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);
      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];

      getMPSProfiler().endProfileKernel(cplState);
    });
  }

  // Results were produced in a contiguous staging tensor; scatter them back
  // into the caller's strided output.
  if (needs_output_copy) {
    output_.copy_(outputTensor);
  }
}
} // namespace at::native |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters