Skip to content

Commit

Permalink
Simplify things by just processing all 4 channels
Browse files Browse the repository at this point in the history
  • Loading branch information
reillyeon committed Apr 30, 2024
1 parent da01d95 commit 285a2dd
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions webnn-conv2d.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<canvas id="input" width="500" height="500"></canvas>
<canvas id="output" width="500" height="500"></canvas>
<script>
const channels = 3;
const channels = 4;

async function createGraph(context) {
const builder = new MLGraphBuilder(context);
Expand All @@ -32,24 +32,17 @@
}

function imageDataToTensor(imageData) {
const tensor = new Float32Array(imageData.width * imageData.height * 3);
for (let srcOffset = 0; srcOffset < imageData.data.length; srcOffset += 4) { // RGBA
const dstOffset = (srcOffset / 4) * 3; // RGB
tensor[dstOffset] = imageData.data[srcOffset] / 256; // R
tensor[dstOffset + 1] = imageData.data[srcOffset + 1] / 256; // G
tensor[dstOffset + 2] = imageData.data[srcOffset + 2] / 256; // B
const tensor = new Float32Array(imageData.data.length);
for (let i = 0; i < imageData.data.length; ++i) {
tensor[i] = imageData.data[i] / 256;
}
return tensor;
}

function tensorToImageData(tensor, width, height) {
const imageData = new ImageData(width, height);
for (let dstOffset = 0; dstOffset < imageData.data.length; dstOffset += 4) { // RGBA
const srcOffset = (dstOffset / 4) * 3; // RGB
imageData.data[dstOffset] = tensor[srcOffset] * 256; // R
imageData.data[dstOffset + 1] = tensor[srcOffset + 1] * 256; // G
imageData.data[dstOffset + 2] = tensor[srcOffset + 2] * 256; // B
imageData.data[dstOffset + 3] = 255; // A
for (let i = 0; i < tensor.length; ++i) {
imageData.data[i] = tensor[i] * 256;
}
return imageData;
}
Expand All @@ -59,7 +52,7 @@
const {graph, outputWidth, outputHeight} = await createGraph(context);

const input = imageDataToTensor(inputData);
const output = new Float32Array(outputWidth * outputHeight * 3);
const output = new Float32Array(outputWidth * outputHeight * channels);

const {inputs, outputs} = await context.compute(graph, {input}, {output});

Expand Down

0 comments on commit 285a2dd

Please sign in to comment.