Skip to content

Commit

Permalink
Add a grayscale filter
Browse files Browse the repository at this point in the history
  • Loading branch information
reillyeon committed May 3, 2024
1 parent 10e7725 commit 1fad606
Showing 1 changed file with 62 additions and 9 deletions.
71 changes: 62 additions & 9 deletions webnn-conv2d.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@
</select>
</p>
<p>
<label for="filterSize">Filter size:</label>
<input id="filterSize" type="number" min="1" max="50" value="5" disabled>
<label for="filterType">Filter type:</label>
<select id="filterType" disabled>
<option selected value="blur">Blur</option>
<option value="grayscale">Grayscale</option>
</select>
</p>
<p>
<label for="blurRadius">Blur radius:</label>
<input id="blurRadius" type="number" min="1" max="50" value="2" disabled>
</p>
<p>
<label for="inputLayout">Input layout:</label>
Expand All @@ -34,7 +41,7 @@
const channels = 4;
let inputData;

async function createGraph(context) {
async function createBlurGraph(context) {
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);

Expand All @@ -45,8 +52,8 @@
}

// Right now Chromium only supports one filter layout for each input layout.
const filterHeight = Number(filterSizeElement.value);
const filterWidth = Number(filterSizeElement.value);
const filterHeight = Number(blurRadiusElement.value) * 2 + 1;
const filterWidth = Number(blurRadiusElement.value) * 2 + 1;
const filterLayout = inputLayout == 'nchw' ? 'oihw' : 'ihwo';
const filterShape =
filterLayout == 'oihw' ?
Expand Down Expand Up @@ -75,6 +82,43 @@
}
}

async function createGrayscaleGraph(context) {
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);

let input = builder.input(
'input', {dataType: 'float32', dimensions: [1, 500, 500, channels]});
if (inputLayout == 'nchw') {
input = builder.transpose(input, {permutation: [0, 3, 1, 2]})
}

// Right now Chromium only supports one filter layout for each input layout.
const filterLayout = inputLayout == 'nchw' ? 'oihw' : 'ohwi';
const filterShape =
filterLayout == 'oihw' ?
[channels, channels, 1, 1] : [channels, 1, 1, channels]

// Mix the RGB channels but not the alpha channel.
const filterData = Float32Array.of(
1/3, 1/3, 1/3, 0,
1/3, 1/3, 1/3, 0,
1/3, 1/3, 1/3, 0,
0, 0, 0, 1);
const filter = builder.constant(
{dataType: 'float32', dimensions: filterShape}, filterData);

let output = builder.conv2d(input, filter, {inputLayout, filterLayout});
if (inputLayout == 'nchw') {
output = builder.transpose(output, {permutation: [0, 2, 3, 1]})
}

return {
graph: await builder.build({'output': output}),
outputHeight: output.shape()[1],
outputWidth: output.shape()[2],
}
}

function imageDataToTensor(imageData) {
const tensor = new Float32Array(imageData.data.length);
for (let i = 0; i < imageData.data.length; ++i) {
Expand Down Expand Up @@ -104,7 +148,9 @@
const context = await navigator.ml.createContext({deviceType: deviceTypeElement.value});

const buildStart = performance.now();
const {graph, outputWidth, outputHeight} = await createGraph(context);
const {graph, outputWidth, outputHeight} =
filterTypeElement.value == 'blur' ?
await createBlurGraph(context) : await createGrayscaleGraph(context);
const buildEnd = performance.now();

const input = imageDataToTensor(inputData);
Expand Down Expand Up @@ -137,7 +183,8 @@
}

deviceTypeElement.disabled = false;
filterSizeElement.disabled = false;
filterTypeElement.disabled = false;
blurRadiusElement.disabled = false;
inputLayoutElement.disabled = false;
run();
};
Expand All @@ -147,8 +194,14 @@
const deviceTypeElement = document.getElementById('deviceType');
deviceTypeElement.onchange = run

const filterSizeElement = document.getElementById('filterSize');
filterSizeElement.onchange = run
const filterTypeElement = document.getElementById('filterType');
filterTypeElement.onchange = () => {
blurRadiusElement.disabled = filterTypeElement.value != 'blur';
run();
};

const blurRadiusElement = document.getElementById('blurRadius');
blurRadiusElement.onchange = run

const inputLayoutElement = document.getElementById('inputLayout');
inputLayoutElement.onchange = run
Expand Down

0 comments on commit 1fad606

Please sign in to comment.