Skip to content

Commit

Permalink
Add build and compute buttons to the conv2d demo
Browse files Browse the repository at this point in the history
  • Loading branch information
reillyeon committed Jul 13, 2024
1 parent 8acd727 commit 591e88b
Showing 1 changed file with 59 additions and 18 deletions.
77 changes: 59 additions & 18 deletions webnn-conv2d.html
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,19 @@
<td><canvas id="output" width="500" height="500"></canvas></td>
</tr>
</table>
<p>
<button id="build" disabled>Build</button> <button disabled id="compute">Compute</button>
</p>
<pre id="status"></pre>
<script>
const channels = 4;
let inputData;

let context;
let graph;
let outputWidth;
let outputHeight;

async function createBlurGraph(context) {
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);
Expand Down Expand Up @@ -139,31 +147,46 @@
const outputCtx = outputCanvas.getContext('2d');
const statusSpan = document.getElementById('status');

async function run() {
async function computeGraph() {
statusSpan.textContent = '';

try {
outputCtx.clearRect(0, 0, outputCanvas.width, outputCanvas.height);

const context = await navigator.ml.createContext({deviceType: deviceTypeElement.value});

const buildStart = performance.now();
const {graph, outputWidth, outputHeight} =
filterTypeElement.value == 'blur' ?
await createBlurGraph(context) : await createGrayscaleGraph(context);
const buildEnd = performance.now();

const input = imageDataToTensor(inputData);
const output = new Float32Array(outputWidth * outputHeight * channels);

const computeStart = performance.now();
performance.mark('compute-start');
const {inputs, outputs} = await context.compute(graph, {input}, {output});
const computeEnd = performance.now();
performance.mark('compute-end');

const outputData = tensorToImageData(outputs.output, outputWidth, outputHeight);
outputCtx.putImageData(outputData, 0, 0);

statusSpan.textContent = `Build took ${(buildEnd - buildStart).toFixed(1)}ms. Compute took ${(computeEnd - computeStart).toFixed(1)}ms.`;
const computeMeasure = performance.measure('compute-measure', 'compute-start', 'compute-end');
statusSpan.textContent = `Compute took ${(computeMeasure.duration).toFixed(1)}ms.`;
} catch (e) {
statusSpan.textContent = e.stack;
}
}

async function buildGraph() {
statusSpan.textContent = '';

try {
context = await navigator.ml.createContext({deviceType: deviceTypeElement.value});

performance.mark('build-start');
({graph, outputWidth, outputHeight} =
filterTypeElement.value == 'blur' ?
await createBlurGraph(context) : await createGrayscaleGraph(context));
performance.mark('build-end');

buildButton.disabled = true;
computeButton.disabled = false;

const buildMeasure = performance.measure('build-duration', 'build-start', 'build-end');
statusSpan.textContent = `Build took ${(buildMeasure.duration).toFixed(1)}ms.`;
} catch (e) {
statusSpan.textContent = e.stack;
}
Expand All @@ -178,33 +201,51 @@
inputData = inputCtx.getImageData(0, 0, image.width, image.height);

if (!('ml' in navigator)) {
statusSpan.textContent = 'WebNN not supported in your browser.';
statusSpan.textContent = 'WebNN is not supported in your browser.';
return;
}

deviceTypeElement.disabled = false;
filterTypeElement.disabled = false;
blurRadiusElement.disabled = false;
inputLayoutElement.disabled = false;
run();
buildButton.disabled = false;
};
image.src = 'photo.jpg';
}

function contextOptionsChanged() {
buildButton.disabled = false;
context = null;
graphOptionsChanged();
}

function graphOptionsChanged() {
buildButton.disabled = false;
computeButton.disabled = true;
graph = null;
}

const deviceTypeElement = document.getElementById('deviceType');
deviceTypeElement.onchange = run
deviceTypeElement.onchange = contextOptionsChanged

const filterTypeElement = document.getElementById('filterType');
filterTypeElement.onchange = () => {
blurRadiusElement.disabled = filterTypeElement.value != 'blur';
run();
graphOptionsChanged();
};

const blurRadiusElement = document.getElementById('blurRadius');
blurRadiusElement.onchange = run
blurRadiusElement.onchange = graphOptionsChanged

const inputLayoutElement = document.getElementById('inputLayout');
inputLayoutElement.onchange = run
inputLayoutElement.onchange = graphOptionsChanged

const buildButton = document.getElementById('build');
buildButton.onclick = buildGraph;

const computeButton = document.getElementById('compute');
computeButton.onclick = computeGraph;

loadInput();
</script>
Expand Down

0 comments on commit 591e88b

Please sign in to comment.