Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Optimize ConvTranspose (Continue) #23429

Merged
merged 3 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 90 additions & 24 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ export const createConvTranspose2DProgramInfo = (
const inputChannelsPerGroup = wShape[2] / group;
const outputChannelsPerGroup = wShape[3];
const aComponents = isChannelsLast ? getMaxComponents(inputChannelsPerGroup) : 1;
const packInputAs4 = isChannelsLast && outputChannelsPerGroup === 1;
const inputChannelsPerGroupInt = packInputAs4
? Math.floor(inputChannelsPerGroup / 4) * 4
: Math.floor(inputChannelsPerGroup / aComponents) * aComponents;
const inputChannelsRemainder = inputChannelsPerGroup - inputChannelsPerGroupInt;
const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1;
const bComponents = isChannelsLast ? (outputChannelsPerGroup === 1 ? aComponents : components) : 1;
const outputSize = ShapeUtil.size(outputShape) / components;
Expand Down Expand Up @@ -78,7 +83,7 @@ export const createConvTranspose2DProgramInfo = (
{ type: DataType.uint32, data: dilations },
{ type: DataType.uint32, data: effectiveFilterDims },
{ type: DataType.int32, data: pads },
{ type: DataType.uint32, data: inputChannelsPerGroup },
{ type: DataType.uint32, data: inputChannelsPerGroupInt },
{ type: DataType.uint32, data: outputChannelsPerGroup },
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims),
];
Expand Down Expand Up @@ -114,16 +119,40 @@ export const createConvTranspose2DProgramInfo = (

const calculateResult = (): string => {
let calcStr = '';
if (aComponents === 1) {
calcStr += `
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
dotProd = dotProd + xValue * wValue;`;
if (packInputAs4) {
if (aComponents === 4) {
calcStr += `
let xValue = ${dy.getByOffset('x_offset')};
let wValue = ${w.getByOffset('w_offset')};
dotProd = dotProd + dot(xValue, wValue);
x_offset += 1u;
w_offset += 1u;`;
} else if (aComponents === 2) {
calcStr += `
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}));
x_offset += 2u;
w_offset += 2u;`;
} else if (aComponents === 1) {
calcStr += `
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}, ${dy.getByOffset('x_offset + 2u')}, ${dy.getByOffset('x_offset + 3u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}, ${w.getByOffset('w_offset + 2u')}, ${w.getByOffset('w_offset + 3u')}));
x_offset += 4u;
w_offset += 4u;`;
}
} else {
if (outputChannelsPerGroup === 1) {
calcStr += `
let xValue = ${
isChannelsLast
? dy.getByOffset(
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
)
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
`;
if (aComponents === 1) {
calcStr += `
let wValue = ${w.getByOffset(`${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)} / ${bComponents}`)};
dotProd = dotProd + dot(xValue, wValue);`;
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
dotProd = dotProd + xValue * wValue;`;
} else {
for (let c = 0; c < aComponents; c++) {
calcStr += `
Expand All @@ -134,6 +163,32 @@ export const createConvTranspose2DProgramInfo = (
}
return calcStr;
};
const calculateRemainder = (): string => {
if (inputChannelsRemainder === 0) {
return '';
}
if (!packInputAs4) {
throw new Error(`packInputAs4 ${packInputAs4} is not true.`);
}
let calcStr = '';
if (aComponents === 1) {
calcStr += 'dotProd = dotProd';
for (let i = 0; i < inputChannelsRemainder; i++) {
calcStr += `
+ ${dy.getByOffset(`x_offset + ${i}`)} * ${w.getByOffset(`w_offset + ${i}`)}`;
}
calcStr += ';';
} else if (aComponents === 2) {
if (inputChannelsRemainder !== 2) {
throw new Error(`Invalid inputChannelsRemainder ${inputChannelsRemainder}.`);
}
calcStr += `
let xValue = ${dy.getByOffset('x_offset')};
let wValue = ${w.getByOffset('w_offset')};
dotProd = dotProd + dot(xValue, wValue);`;
}
return calcStr;
};
const codeSnippet = `
let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)};
let batch = ${output.indicesGet('outputIndices', 0)};
Expand All @@ -148,7 +203,12 @@ export const createConvTranspose2DProgramInfo = (
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd = ${output.type.value}(0.0);
for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
var wR: u32 = 0;
if (uniforms.dilations.x == 1) {
// Minimum wR >= 0 that satisfies (dyRCorner + wR) % (uniforms.strides.x) == 0
wR = u32(((dyRCorner + i32(uniforms.strides.x) - 1) / i32(uniforms.strides.x)) * i32(uniforms.strides.x) - dyRCorner);
}
for (; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
if (wR % uniforms.dilations.x != 0) {
continue;
}
Expand All @@ -158,10 +218,13 @@ export const createConvTranspose2DProgramInfo = (
wRPerm < 0) {
continue;
}
wR = wR + uniforms.strides[0] - 1;
let idyR: u32 = u32(dyR);

for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
var wC: u32 = 0;
if (uniforms.dilations.y == 1) {
// Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0
wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner);
}
for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
if (wC % uniforms.dilations.y != 0) {
continue;
}
Expand All @@ -171,21 +234,24 @@ export const createConvTranspose2DProgramInfo = (
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
wC = wC + uniforms.strides.y - 1;
let idyC: u32 = u32(dyC);
var inputChannel = groupId * uniforms.input_channels_per_group;
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${aComponents}) {
let xValue = ${
isChannelsLast
? dy.getByOffset(
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
)
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
${
packInputAs4
? `
var x_offset = ${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents};
var w_offset = ${w.indicesToOffset(`${w.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${bComponents};
`
: ''
}
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${packInputAs4 ? 4 : aComponents}) {
${calculateResult()}
inputChannel = inputChannel + ${aComponents};
inputChannel = inputChannel + ${packInputAs4 ? 4 : aComponents};
}
${calculateRemainder()}
wC = wC + uniforms.strides.y - 1;
}
wR = wR + uniforms.strides[0] - 1;
}
let value = dotProd${hasBias ? ` + bias[d1 / ${components}]` : ''};
${output.setByOffset('global_idx', 'value')};
Expand All @@ -201,7 +267,7 @@ export const createConvTranspose2DProgramInfo = (
return {
name: 'ConvTranspose2D',
shaderCache: {
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}`,
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}${inputChannelsRemainder}`,
inputDependencies,
},
getRunData: () => ({
Expand Down
146 changes: 146 additions & 0 deletions js/web/test/data/ops/conv-transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,152 @@
}
]
},
{
"name": "ConvTranspose with output channels = 1",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "strides", "data": [2, 2], "type": "ints" }
],
"cases": [
{
"name": "inChannels = 5",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45
],
"dims": [1, 5, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8],
"dims": [5, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
437, 532, 458, 558, 479, 584, 627, 722, 658, 758, 689, 794, 500, 610, 521, 636, 542, 662, 720, 830, 751,
866, 782, 902, 563, 688, 584, 714, 605, 740, 813, 938, 844, 974, 875, 1010
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
},
{
"name": "inChannels = 6",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 1, 2, 3, 4, 5, 6, 7, 8, 9
],
"dims": [1, 6, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4],
"dims": [6, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
438, 534, 460, 562, 482, 590, 630, 726, 664, 766, 698, 806, 504, 618, 526, 646, 548, 674, 732, 846, 766,
886, 800, 926, 570, 702, 592, 730, 614, 758, 834, 966, 868, 1006, 902, 1046
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
},
{
"name": "inChannels = 7",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18
],
"dims": [1, 7, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8],
"dims": [7, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
488, 594, 515, 628, 542, 662, 700, 806, 741, 854, 782, 902, 569, 696, 596, 730, 623, 764, 823, 950, 864,
998, 905, 1046, 650, 798, 677, 832, 704, 866, 946, 1094, 987, 1142, 1028, 1190
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
},
{
"name": "inChannels = 8",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 1, 2, 3, 4, 5, 6, 7, 8, 9
],
"dims": [1, 8, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4],
"dims": [8, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
489, 596, 517, 632, 545, 668, 703, 810, 747, 862, 791, 914, 573, 704, 601, 740, 629, 776, 835, 966, 879,
1018, 923, 1070, 657, 812, 685, 848, 713, 884, 967, 1122, 1011, 1174, 1055, 1226
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
}
]
},
{
"name": "ConvTranspose without bias addition C",
"operator": "ConvTranspose",
Expand Down
Loading