From a3833a5e79784514ee1dc2b2dcd01b83c31aa7e3 Mon Sep 17 00:00:00 2001 From: xhcao Date: Thu, 2 Jan 2025 07:58:54 +0800 Subject: [PATCH] [js/webgpu] validate transpose perm if specified (#23197) ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 5059645211aea..a348c1e637f4c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -13,14 +13,18 @@ export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; } -const validateInputs = (inputs: readonly TensorView[]): void => { +const validateInputs = (inputs: readonly TensorView[], perm: readonly number[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Transpose requires 1 input.'); } + + if (perm.length !== 0 && perm.length !== inputs[0].dims.length) { + throw new Error(`perm size ${perm.length} does not match input rank ${inputs[0].dims.length}`); + } }; const getAdjustedPerm = (inputRank: number, perm: number[]): number[] => - perm && perm.length !== inputRank ? [...new Array(inputRank).keys()].reverse() : perm; + perm.length !== 0 ? perm : [...new Array(inputRank).keys()].reverse(); const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm)); @@ -191,7 +195,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu }; export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => { - validateInputs(context.inputs); + validateInputs(context.inputs, attributes.perm); context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm)); };