Add option to run WebNN conv2d demo with MLBuffer
reillyeon committed Jul 15, 2024
1 parent 591e88b commit 971ab5f
Showing 1 changed file with 67 additions and 26 deletions.
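
In short, the patch adds a "Use MLBuffer" checkbox to the demo. When it is checked, the demo allocates an input and an output MLBuffer on the MLContext after the graph is built, uploads the input tensor with writeBuffer(), and then runs the graph with dispatch() and readBuffer() instead of compute(). A condensed sketch of that round trip, using only the calls that appear in the diff below (the helper name runWithBuffers is illustrative, not part of the patch):

  // Minimal sketch of the MLBuffer path this commit adds; assumes `context` is an
  // MLContext, `graph` a built MLGraph, and the descriptors match the graph's operands.
  async function runWithBuffers(context, graph, inputDescriptor, outputDescriptor, inputData) {
    const inputBuffer = context.createBuffer(inputDescriptor);    // device-side buffers,
    const outputBuffer = context.createBuffer(outputDescriptor);  // allocated up front
    context.writeBuffer(inputBuffer, inputData);                  // upload the Float32Array input
    context.dispatch(graph, {input: inputBuffer}, {output: outputBuffer});
    const result = new Float32Array(await context.readBuffer(outputBuffer));
    inputBuffer.destroy();                                        // buffers are explicit resources
    outputBuffer.destroy();
    return result;
  }

In the demo itself the buffers are created once after the graph is built (createMLBuffers) and destroyed when the context options change or the checkbox is cleared (destroyMLBuffers), so repeated runs reuse the same buffers instead of passing fresh typed arrays to compute() each time.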
93 changes: 67 additions & 26 deletions webnn-conv2d.html
@@ -29,6 +29,10 @@
<option selected value="nhwc">Channels last (NHWC)</option>
</select>
</p>
<p>
<label for="dispatch">Use <code>MLBuffer</code>:</label>
<input id="dispatch" type="checkbox">
</p>
<table>
<tr><th>Input</th><th>Output</th></tr>
<tr>
@@ -42,19 +46,21 @@
<pre id="status"></pre>
<script>
const channels = 4;
let inputData;
const inputDescriptor = {dataType: 'float32', dimensions: [1, 500, 500, channels]};

let context;
let graph;
let outputWidth;
let outputHeight;
let inputs = {input: null};
let outputs = {output: null};
let inputBuffer;
let outputBuffer;
let outputDescriptor;

async function createBlurGraph(context) {
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);

let input = builder.input(
'input', {dataType: 'float32', dimensions: [1, 500, 500, channels]});
let input = builder.input('input', inputDescriptor);
if (inputLayout == 'nchw') {
input = builder.transpose(input, {permutation: [0, 3, 1, 2]})
}
@@ -83,19 +89,15 @@
output = builder.transpose(output, {permutation: [0, 2, 3, 1]})
}

return {
graph: await builder.build({'output': output}),
outputHeight: output.shape()[1],
outputWidth: output.shape()[2],
}
graph = await builder.build({'output': output});
outputDescriptor = {dataType: output.dataType(), dimensions: output.shape()};
}

async function createGrayscaleGraph(context) {
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);

let input = builder.input(
'input', {dataType: 'float32', dimensions: [1, 500, 500, channels]});
let input = builder.input('input', inputDescriptor);
if (inputLayout == 'nchw') {
input = builder.transpose(input, {permutation: [0, 3, 1, 2]})
}
@@ -120,11 +122,8 @@
output = builder.transpose(output, {permutation: [0, 2, 3, 1]})
}

return {
graph: await builder.build({'output': output}),
outputHeight: output.shape()[1],
outputWidth: output.shape()[2],
}
graph = await builder.build({'output': output});
outputDescriptor = {dataType: output.dataType(), dimensions: output.shape()};
}

function imageDataToTensor(imageData) {
@@ -153,14 +152,19 @@
try {
outputCtx.clearRect(0, 0, outputCanvas.width, outputCanvas.height);

const input = imageDataToTensor(inputData);
const output = new Float32Array(outputWidth * outputHeight * channels);

performance.mark('compute-start');
const {inputs, outputs} = await context.compute(graph, {input}, {output});
if (dispatchCheckbox.checked) {
context.dispatch(graph, {input: inputBuffer}, {output: outputBuffer});
const buffer = await context.readBuffer(outputBuffer);
outputs.output = new Float32Array(buffer);
} else {
const outputLength = outputDescriptor.dimensions.reduce((acc, value) => acc * value, 1);
outputs.output = new Float32Array(outputLength);
({inputs, outputs} = await context.compute(graph, inputs, outputs));
}
performance.mark('compute-end');

const outputData = tensorToImageData(outputs.output, outputWidth, outputHeight);
const outputData = tensorToImageData(outputs.output, outputDescriptor.dimensions[2], outputDescriptor.dimensions[1]);
outputCtx.putImageData(outputData, 0, 0);

const computeMeasure = performance.measure('compute-measure', 'compute-start', 'compute-end');
@@ -177,11 +181,17 @@
context = await navigator.ml.createContext({deviceType: deviceTypeElement.value});

performance.mark('build-start');
({graph, outputWidth, outputHeight} =
filterTypeElement.value == 'blur' ?
await createBlurGraph(context) : await createGrayscaleGraph(context));
if (filterTypeElement.value == 'blur') {
await createBlurGraph(context);
} else {
await createGrayscaleGraph(context);
}
performance.mark('build-end');

if (dispatchCheckbox.checked) {
createMLBuffers();
}

buildButton.disabled = true;
computeButton.disabled = false;

@@ -192,13 +202,32 @@
}
}

function createMLBuffers() {
inputBuffer = context.createBuffer(inputDescriptor);
context.writeBuffer(inputBuffer, inputs.input);

outputBuffer = context.createBuffer(outputDescriptor);
}

function destroyMLBuffers() {
if (inputBuffer) {
inputBuffer.destroy();
inputBuffer = null;
}
if (outputBuffer) {
outputBuffer.destroy();
outputBuffer = null;
}
}

function loadInput() {
const image = new Image();
image.onload = () => {
const inputCanvas = document.getElementById('input');
const inputCtx = inputCanvas.getContext('2d');
inputCtx.drawImage(image, 0, 0);
inputData = inputCtx.getImageData(0, 0, image.width, image.height);
const inputData = inputCtx.getImageData(0, 0, image.width, image.height);
inputs.input = imageDataToTensor(inputData);

if (!('ml' in navigator)) {
statusSpan.textContent = 'WebNN is not supported in your browser.';
@@ -215,6 +244,7 @@
}

function contextOptionsChanged() {
destroyMLBuffers();
buildButton.disabled = false;
context = null;
graphOptionsChanged();
@@ -241,6 +271,17 @@
const inputLayoutElement = document.getElementById('inputLayout');
inputLayoutElement.onchange = graphOptionsChanged

const dispatchCheckbox = document.getElementById('dispatch');
dispatchCheckbox.onchange = () => {
if (dispatchCheckbox.checked) {
if (context) {
createMLBuffers();
}
} else {
destroyMLBuffers();
}
};

const buildButton = document.getElementById('build');
buildButton.onclick = buildGraph;

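
A side note, not part of the commit: at the time this landed, MLBuffer was still an experimental WebNN proposal, so a page embedding this demo might want to disable the checkbox when the context does not expose the buffer API. A hypothetical guard, assuming the same `dispatchCheckbox` element and `context` variable used in the diff (the capability check itself is not something the demo performs):

  // Hypothetical: only offer the MLBuffer path when the context exposes the
  // MLContext methods this commit relies on (createBuffer and dispatch).
  function updateDispatchAvailability(context, dispatchCheckbox) {
    const supported = !!context && 'createBuffer' in context && 'dispatch' in context;
    dispatchCheckbox.disabled = !supported;
    if (!supported) {
      dispatchCheckbox.checked = false;
    }
  }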