Skip to content

Commit

Permalink
Add option to run WebNN add demo with MLBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
reillyeon committed Jul 13, 2024
1 parent 8224943 commit 8acd727
Showing 1 changed file with 69 additions and 4 deletions.
73 changes: 69 additions & 4 deletions webnn-add.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
<option value="high-performance">High performance</option>
</select>
</p>
<p>
<label for="dispatch">Use <code>MLBuffer</code>:</label>
<input id="dispatch" type="checkbox">
</p>
<p>
<button id="build">Build</button> <button disabled id="compute">Compute</button>
</p>
Expand All @@ -35,14 +39,60 @@
const cOutput = document.getElementById('c');
const deviceOption = document.getElementById('device');
const powerOption = document.getElementById('power');
const threadsOption = document.getElementById('threads');
const dispatchCheckbox = document.getElementById('dispatch');
const buildButton = document.getElementById('build');
const computeButton = document.getElementById('compute');
const outputSpan = document.getElementById('output');

const operandType = {dataType: 'float32', dimensions: [1]};
const inputMLBuffers = {'a': null, 'b': null};
const outputMLBuffers = {'c': null};

let context;
let graph;

function createMLBuffers() {
inputMLBuffers.a = context.createBuffer(operandType);
inputMLBuffers.b = context.createBuffer(operandType);
outputMLBuffers.c = context.createBuffer(operandType);
}

function destroyMLBuffers() {
if (inputMLBuffers.a) {
inputMLBuffers.a.destroy();
inputMLBuffers.a = null;
}
if (inputMLBuffers.b) {
inputMLBuffers.b.destroy();
inputMLBuffers.b = null;
}
if (outputMLBuffers.c) {
outputMLBuffers.c.destroy();
outputMLBuffers.c = null;
}
}

function contextOptionsChanged() {
destroyMLBuffers();
buildButton.disabled = false;
computeButton.disabled = true;
context = null;
graph = null;
}

deviceOption.addEventListener('change', contextOptionsChanged);
powerOption.addEventListener('change', contextOptionsChanged);

dispatchCheckbox.addEventListener('change', () => {
if (dispatchCheckbox.checked) {
if (context) {
createMLBuffers();
}
} else {
destroyMLBuffers();
}
});

buildButton.addEventListener('click', async () => {
try {
computeButton.disabled = true;
Expand All @@ -54,15 +104,19 @@
console.log(options);
context = await navigator.ml.createContext(options);

if (dispatchCheckbox.checked) {
createMLBuffers();
}

const builder = new MLGraphBuilder(context);
const operandType = {dataType: 'float32', dimensions: [1]};
const a = builder.input('a', operandType);
const b = builder.input('b', operandType);
const c = builder.add(a, b);
graph = await builder.build({'c': c});

outputSpan.textContent = 'Graph ready!'
computeButton.disabled = false;
buildButton.disabled = true;
} catch (e) {
outputSpan.textContent = `${e.name}: ${e.message}`;
}
Expand All @@ -78,10 +132,21 @@

inputs.a[0] = aInput.value;
inputs.b[0] = bInput.value;
({inputs, outputs} = await context.compute(graph, inputs, outputs));

outputSpan.textContent = 'Compute finished!';
if (dispatchCheckbox.checked) {
context.writeBuffer(inputMLBuffers.a, inputs.a);
context.writeBuffer(inputMLBuffers.b, inputs.b);

context.dispatch(graph, inputMLBuffers, outputMLBuffers);

const buffer = await context.readBuffer(outputMLBuffers.c);
outputs.c = new Float32Array(buffer);
} else {
({inputs, outputs} = await context.compute(graph, inputs, outputs));
}

cOutput.textContent = outputs.c[0];
outputSpan.textContent = 'Compute finished!';
} catch (e) {
outputSpan.textContent = `${e.name}: ${e.message}`;
}
Expand Down

0 comments on commit 8acd727

Please sign in to comment.