Skip to content

Commit

Permalink
Support float16 in conv2d demo
Browse files Browse the repository at this point in the history
  • Loading branch information
reillyeon committed Jul 26, 2024
1 parent af15d0c commit 96b97db
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 76 deletions.
62 changes: 62 additions & 0 deletions half-floats.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

// ref: http://stackoverflow.com/questions/32633585/how-do-you-convert-to-half-floats-in-javascript
//
// Encodes a JavaScript number as its nearest IEEE 754 binary16 (half
// precision) value, returned as an unsigned 16-bit integer suitable for
// storage in a Uint16Array.
function toHalf(value) {
  const floatView = new Float32Array(1);
  const int32View = new Int32Array(floatView.buffer);

  // This method is faster than the OpenEXR implementation (very often
  // used, eg. in Ogre), with the additional benefit of rounding, inspired
  // by James Tursa's half-precision code.

  floatView[0] = value;
  const x = int32View[0];

  let bits = (x >> 16) & 0x8000; // Get the sign
  let m = (x >> 12) & 0x07ff; // Keep one extra bit for rounding
  const e = (x >> 23) & 0xff; // Using int is faster here

  // If zero, or denormal, or exponent underflows too much for a denormal
  // half, return signed zero.
  if (e < 103) {
    return bits;
  }

  // If NaN, return NaN. If Inf or exponent overflow, return Inf.
  if (e > 142) {
    bits |= 0x7c00;
    // If exponent was 0xff and one mantissa bit was set, it means NaN,
    // not Inf, so make sure we set one mantissa bit too.
    //
    // BUGFIX: the previous expression
    //   ((e == 255) ? 0 : 1) && (x & 0x007fffff)
    // was inverted: it contributed 0 for NaN inputs (turning NaN into
    // Inf) and OR'd up to 23 raw mantissa bits in for ordinary finite
    // overflow, producing a result wider than 16 bits.
    if (e === 255 && (x & 0x007fffff) !== 0) {
      bits |= 0x0200;
    }
    return bits;
  }

  // If exponent underflows but not too much, return a denormal
  if (e < 113) {
    m |= 0x0800;
    // Extra rounding may overflow and set mantissa to 0 and exponent
    // to 1, which is OK.
    bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
    return bits;
  }

  bits |= ((e - 112) << 10) | (m >> 1);
  // Extra rounding. An overflow will set mantissa to 0 and increment
  // the exponent, which is OK.
  bits += m & 1;
  return bits;
}

// ref: https://stackoverflow.com/questions/5678432/decompressing-half-precision-floats-in-javascript
//
// Decodes an IEEE 754 binary16 bit pattern (as a 16-bit unsigned
// integer) into a JavaScript number.
function fromHalf(h) {
  const sign = h & 0x8000 ? -1 : 1;
  const exponent = (h >> 10) & 0x1f;
  const fraction = (h & 0x03ff) / 1024;

  // Exponent of all ones encodes infinity (zero mantissa) or NaN.
  if (exponent === 0x1f) {
    return fraction !== 0 ? NaN : sign * Infinity;
  }

  // Exponent of zero encodes signed zero and subnormals.
  if (exponent === 0) {
    return sign * Math.pow(2, -14) * fraction;
  }

  // Normal numbers carry an implicit leading 1 on the mantissa.
  return sign * Math.pow(2, exponent - 15) * (1 + fraction);
}
63 changes: 1 addition & 62 deletions webnn-add.html
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<html>
<head>
<title>WebNN Simple Example</title>
<script src="half-floats.js"></script>
</head>
<body>
<p>
Expand Down Expand Up @@ -93,68 +94,6 @@
}
}

// ref: http://stackoverflow.com/questions/32633585/how-do-you-convert-to-half-floats-in-javascript
function toHalf(value) {
const floatView = new Float32Array(1);
const int32View = new Int32Array(floatView.buffer);

// This method is faster than the OpenEXR implementation (very often
// used, eg. in Ogre), with the additional benefit of rounding, inspired
// by James Tursa's half-precision code.

floatView[0] = value;
const x = int32View[0];

let bits = (x >> 16) & 0x8000; // Get the sign
let m = (x >> 12) & 0x07ff; // Keep one extra bit for rounding
const e = (x >> 23) & 0xff; // Using int is faster here

// If zero, or denormal, or exponent underflows too much for a denormal
// half, return signed zero.
if (e < 103) {
return bits;
}

// If NaN, return NaN. If Inf or exponent overflow, return Inf.
if (e > 142) {
bits |= 0x7c00;
// If exponent was 0xff and one mantissa bit was set, it means NaN,
// not Inf, so make sure we set one mantissa bit too.
bits |= ((e == 255) ? 0 : 1) && (x & 0x007fffff);
return bits;
}

// If exponent underflows but not too much, return a denormal
if (e < 113) {
m |= 0x0800;
// Extra rounding may overflow and set mantissa to 0 and exponent
// to 1, which is OK.
bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
return bits;
}

bits |= ((e - 112) << 10) | (m >> 1);
// Extra rounding. An overflow will set mantissa to 0 and increment
// the exponent, which is OK.
bits += m & 1;
return bits;
}

// ref: https://stackoverflow.com/questions/5678432/decompressing-half-precision-floats-in-javascript
function fromHalf(h) {
const s = (h & 0x8000) >> 15 ? -1 : 1;
const e = (h & 0x7C00) >> 10;
const f = h & 0x03FF;

if (e == 0) {
return s * Math.pow(2, -14) * (f / Math.pow(2, 10));
} else if (e == 0x1F) {
return f ? NaN : (s * Infinity);
}

return s * Math.pow(2, e - 15) * (1 + (f / Math.pow(2, 10)));
}

function maybeEncodeInput(input) {
if (dataTypeOption.value == 'float16') {
return toHalf(input);
Expand Down
78 changes: 64 additions & 14 deletions webnn-conv2d.html
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<html>
<head>
<title>WebNN Conv2D</title>
<script src="half-floats.js"></script>
</head>
<body>
<p>
Expand All @@ -11,6 +12,13 @@
<option value="npu">NPU</option>
</select>
</p>
<p>
Data type:
<select id="dataType">
<option value="float16">float16</option>
<option selected value="float32">float32</option>
</select>
</p>
<p>
<label for="filterType">Filter type:</label>
<select id="filterType" disabled>
Expand Down Expand Up @@ -46,7 +54,7 @@
<pre id="status"></pre>
<script>
const channels = 4;
const inputDescriptor = {dataType: 'float32', dimensions: [1, 500, 500, channels]};
const inputShape = [1, 500, 500, channels];

let context;
let graph;
Expand All @@ -56,11 +64,33 @@
let outputBuffer;
let outputDescriptor;

// Returns the TypedArray constructor used to hold tensor data for the
// given WebNN data type. float16 values are stored as raw 16-bit
// encodings in a Uint16Array (JavaScript has no Float16Array here).
//
// The callers in this file pass `dataTypeOption.value` explicitly, but
// the original signature ignored that argument and re-read the global;
// the defaulted parameter honors the argument while staying compatible
// with argument-free calls.
function getTypedArrayConstructor(dataType = dataTypeOption.value) {
  switch (dataType) {
    case 'float16':
      return Uint16Array;
    case 'float32':
      return Float32Array;
    default:
      // Previously fell through and returned undefined, which crashed
      // later at `new undefined(...)`; fail fast with a clear message.
      throw new Error(`Unsupported data type: ${dataType}`);
  }
}

// Converts filter weights from the Float32Array they are built in to
// the array type matching the selected graph data type: for float16
// each element is re-encoded via toHalf() into a Uint16Array; for
// float32 the input array is returned unchanged (no copy).
//
// `dataType` defaults to the page's data-type selection, keeping the
// original zero-extra-argument call sites working.
function maybeConvertFilterData(float32Data, dataType = dataTypeOption.value) {
  switch (dataType) {
    case 'float16':
      // Uint16Array.from replaces the manual fill loop and avoids the
      // previous `const` declared directly inside a case clause.
      return Uint16Array.from(float32Data, toHalf);
    case 'float32':
      return float32Data;
  }
}

async function createBlurGraph(context) {
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);

let input = builder.input('input', inputDescriptor);
let input = builder.input('input', {dataType: dataTypeOption.value, dimensions: inputShape});
if (inputLayout == 'nchw') {
input = builder.transpose(input, {permutation: [0, 3, 1, 2]})
}
Expand All @@ -76,10 +106,11 @@

// A simple blur filter is easy because the layout doesn't matter, the
// elements simply have to sum to 1.
const filterData = new Float32Array(filterHeight * filterWidth * channels);
let filterData = new Float32Array(filterHeight * filterWidth * channels);
filterData.fill(1 / (filterHeight * filterWidth));
filterData = maybeConvertFilterData(filterData);
const filter = builder.constant(
{dataType: 'float32', dimensions: filterShape}, filterData);
{dataType: dataTypeOption.value, dimensions: filterShape}, filterData);

let output = builder.conv2d(input, filter, {
inputLayout, filterLayout,
Expand All @@ -97,7 +128,7 @@
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);

let input = builder.input('input', inputDescriptor);
let input = builder.input('input', {dataType: dataTypeOption.value, dimensions: inputShape});
if (inputLayout == 'nchw') {
input = builder.transpose(input, {permutation: [0, 3, 1, 2]})
}
Expand All @@ -109,13 +140,13 @@
[channels, channels, 1, 1] : [channels, 1, 1, channels]

// Mix the RGB channels but not the alpha channel.
const filterData = Float32Array.of(
const filterData = maybeConvertFilterData(Float32Array.of(
1/3, 1/3, 1/3, 0,
1/3, 1/3, 1/3, 0,
1/3, 1/3, 1/3, 0,
0, 0, 0, 1);
0, 0, 0, 1));
const filter = builder.constant(
{dataType: 'float32', dimensions: filterShape}, filterData);
{dataType: dataTypeOption.value, dimensions: filterShape}, filterData);

let output = builder.conv2d(input, filter, {inputLayout, filterLayout});
if (inputLayout == 'nchw') {
Expand All @@ -127,17 +158,32 @@
}

// Converts canvas ImageData (8-bit RGBA) into a tensor for the graph.
// Each channel byte is scaled from [0, 255] into [0, 1) by dividing by
// 256, then either re-encoded as float16 bits or stored as float32,
// per the selected data type.
function imageDataToTensor(imageData) {
  const TypedArrayCtor = getTypedArrayConstructor(dataTypeOption.value);
  const tensor = new TypedArrayCtor(imageData.data.length);
  // Hoist the data-type branch out of the per-pixel loop; the original
  // re-evaluated a switch on the same value for every element.
  const encode = dataTypeOption.value === 'float16' ? toHalf : (v) => v;
  for (let i = 0; i < imageData.data.length; ++i) {
    tensor[i] = encode(imageData.data[i] / 256);
  }
  return tensor;
}

// Converts a tensor produced by the graph back into canvas ImageData,
// decoding float16 bit patterns when needed and scaling values from
// [0, 1) back to the 0-255 byte range (inverse of imageDataToTensor).
function tensorToImageData(tensor, width, height) {
  const imageData = new ImageData(width, height);
  // Hoist the data-type branch out of the per-pixel loop; the original
  // re-evaluated a switch on the same value for every element.
  const decode = dataTypeOption.value === 'float16' ? fromHalf : (v) => v;
  for (let i = 0; i < tensor.length; ++i) {
    imageData.data[i] = decode(tensor[i]) * 256;
  }
  return imageData;
}
Expand All @@ -151,15 +197,16 @@

try {
outputCtx.clearRect(0, 0, outputCanvas.width, outputCanvas.height);
const typedArray = getTypedArrayConstructor(dataTypeOption.value);

performance.mark('compute-start');
if (dispatchCheckbox.checked) {
context.dispatch(graph, {input: inputBuffer}, {output: outputBuffer});
const buffer = await context.readBuffer(outputBuffer);
outputs.output = new Float32Array(buffer);
outputs.output = new typedArray(buffer);
} else {
const outputLength = outputDescriptor.dimensions.reduce((acc, value) => acc * value, 1);
outputs.output = new Float32Array(outputLength);
outputs.output = new typedArray(outputLength);
({inputs, outputs} = await context.compute(graph, inputs, outputs));
}
performance.mark('compute-end');
Expand Down Expand Up @@ -203,7 +250,7 @@
}

function createMLBuffers() {
inputBuffer = context.createBuffer(inputDescriptor);
inputBuffer = context.createBuffer({dataType: dataTypeOption.value, dimensions: inputShape});
context.writeBuffer(inputBuffer, inputs.input);

outputBuffer = context.createBuffer(outputDescriptor);
Expand Down Expand Up @@ -259,6 +306,9 @@
const deviceTypeElement = document.getElementById('deviceType');
deviceTypeElement.onchange = contextOptionsChanged

const dataTypeOption = document.getElementById('dataType');
dataTypeOption.onchange = graphOptionsChanged;

const filterTypeElement = document.getElementById('filterType');
filterTypeElement.onchange = () => {
blurRadiusElement.disabled = filterTypeElement.value != 'blur';
Expand Down

0 comments on commit 96b97db

Please sign in to comment.