From 2172e484b8750ccaeda08ff9d6257c56c1769413 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 25 Nov 2020 15:51:19 +0530 Subject: [PATCH 01/72] Add support for float16 tensor types. --- Athos/TFCompiler/Graph.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index 3d9a7b3c..014a030c 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -36,12 +36,14 @@ def errIfTokensNotMinLen(tokens, minlen, lineNum, entity): class DataTypeEnum(enum.Enum): DT_INVALID = 0 DT_FLOAT = 1 - DT_BOOL = 2 - DT_INT32 = 3 - DT_INT64 = 4 + DT_FLOAT16 = 2 + DT_BOOL = 3 + DT_INT32 = 4 + DT_INT64 = 5 def Parse(str): if (str == "DT_FLOAT"): return DataTypeEnum.DT_FLOAT + elif (str == "DT_HALF"): return DataTypeEnum.DT_FLOAT16 elif (str == "DT_BOOL"): return DataTypeEnum.DT_BOOL elif (str == "DT_INT32"): return DataTypeEnum.DT_INT32 elif (str == "DT_INT64"): return DataTypeEnum.DT_INT64 @@ -51,6 +53,7 @@ def Parse(str): def Size(dt): if (dt == DataTypeEnum.DT_INVALID): return 0 elif (dt == DataTypeEnum.DT_FLOAT): return 4 + elif (dt == DataTypeEnum.DT_FLOAT16): return 2 elif (dt == DataTypeEnum.DT_BOOL): return 1 elif (dt == DataTypeEnum.DT_INT32): return 4 elif (dt == DataTypeEnum.DT_INT64): return 8 @@ -191,6 +194,8 @@ def __convToBytes(self): self.__valArr = [self.__valInput]*numElements elif ((self.__dtype == DataTypeEnum.DT_FLOAT) and (self.__valInput is not None)): self.__valArr = [self.__valInput]*numElements + elif ((self.__dtype == DataTypeEnum.DT_FLOAT16) and (self.__valInput is not None)): + self.__valArr = [self.__valInput]*numElements elif ((self.__dtype == DataTypeEnum.DT_INT32 or self.__dtype == DataTypeEnum.DT_INT64) and (self.__valInput is not None)): self.__valArr = [self.__valInput]*numElements @@ -310,6 +315,8 @@ def getContentAsValArr(self): # self.__valArr = returnArr if self.__dtype == DataTypeEnum.DT_FLOAT: dtype = numpy.dtype(' "dtype" So we can directly refer to the attributes without adding double quotes to them. --- Athos/TFCompiler/Graph.py | 10 ++-- Athos/TFCompiler/ProcessTFGraph.py | 14 +++--- Athos/TFCompiler/TFNodesAST.py | 74 +++++++++++++++--------------- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index 014a030c..078a3bb5 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -520,7 +520,7 @@ def getAttrMapRef(self): return self.__attr def getAttrVal(self, attrName): - qName = '"' + attrName + '"' + qName = attrName if not qName in self.__attr: return None return self.__attr[qName] @@ -541,7 +541,7 @@ def readAttrFromFilePointer(self, fileP, cnt): #keyStr is already non-None .. there is then probably some error print("Too many keys found while parsing attr for node at line =", cnt, file=sys.stderr) return (False, cnt) - keyStr = tokens[1] + keyStr = tokens[1][1:-1] elif (curToken == "value"): curVal = Value() (noParseError, cnt) = curVal.readFromFilePointer(fileP, cnt) @@ -570,13 +570,13 @@ def readFromFilePointer(self, fileP, cnt): return (True, cnt) elif (curToken == "name:"): if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt) - self.__name = tokens[1] + self.__name = tokens[1][1:-1] elif (curToken == "op:"): if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt) - self.__op = tokens[1] + self.__op = tokens[1][1:-1] elif (curToken == "input:"): if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt) - self.__inputs.append(tokens[1]) + self.__inputs.append(tokens[1][1:-1]) elif (curToken == "attr"): (noParseError, cnt) = self.readAttrFromFilePointer(fileP, cnt) if (not(noParseError)): diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 7955a521..d2b409b6 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -36,7 +36,7 @@ def checkTFNodeNameForEq(curNodeOp:str, givenOp:str): def generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict): curNodeOp = curNode.getOp() ast = None - func = getattr(TFNodesAST, curNodeOp[1:-1]) #To remove the " at the begin and end + func = getattr(TFNodesAST, curNodeOp) (assignedVarAST, curAST) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) return (assignedVarAST, curAST) @@ -53,8 +53,8 @@ def generateIRCode(graph, extraInfoDict): assert(curInp in dictNodeNameToOutVarStr) #Consequence of topological sorting of the TF graph (assignedVarAST, curAst) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict) - mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp()[1:-1], - AST.ASTNode.mtdKeyTFNodeName : curNode.getName()[1:-1]} + mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(), + AST.ASTNode.mtdKeyTFNodeName : curNode.getName()} if (curAst is None): dictNodeNameToOutVarStr[curNode.getName()] = None @@ -104,7 +104,7 @@ def prefixAllPlaceHolderNodes(graph): placeHolderNodes = [] remNodes = [] for curNode in allNodes: - if (curNode.getOp() == "\"Placeholder\"" or curNode.getOp() == "\"VariableV2\""): + if (curNode.getOp() == "Placeholder" or curNode.getOp() == "VariableV2"): # Assert this is indeed a leaf node assert(len(curNode.getInputsRef()) == 0) placeHolderNodes.append(curNode) @@ -138,19 +138,19 @@ def main(): for curNode in graph.getAllNodesRef(): inputsRef = curNode.getInputsRef() for i,curInput in enumerate(inputsRef): - if (curInput.startswith('"^')): + if (curInput.startswith('^')): # My hypothesis from empirical observation is that inputs which have '^' ahead of the node name # denote control flow dependency and not data dependency. # For all purposes for this compilation, control and data dependency is considered same. # The reasoning being that everything is serial -- and graph execution is done in a # a topological sort. - inputsRef[i] = '"' + curInput.split('^')[-1] + inputsRef[i] = curInput.split('^')[-1] # Create extra info dict # Format : (sizeInfo) extraInfoDict = {} for k,v in sizeInfo.items(): - extraInfoDict["\"" + k + "\""] = (v,) + extraInfoDict[k] = (v,) for curNode in graph.getAllNodesRef(): if (curNode.getName() not in extraInfoDict): extraInfoDict[curNode.getName()] = (None,) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 0920c7ac..74cc02ac 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -68,17 +68,17 @@ def MatMul(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : attrMapRef = curNode.getAttrMapRef() transposeABool = transposeBBool = False - if ("\"transpose_a\"" in attrMapRef): - transposeABool = attrMapRef["\"transpose_a\""].getB() - if ("\"transpose_b\"" in attrMapRef): - transposeBBool = attrMapRef["\"transpose_b\""].getB() + if ("transpose_a" in attrMapRef): + transposeABool = attrMapRef["transpose_a"].getB() + if ("transpose_b" in attrMapRef): + transposeBBool = attrMapRef["transpose_b"].getB() if (transposeABool): inp1AST = AST.Transp(inp1AST) if (transposeBBool): inp2AST = AST.Transp(inp2AST) return (None, AST.BOp(inp1AST, TFNodesAST.getOperatorsIdx('*'), inp2AST)) def Placeholder(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] - curNodeInputType = curNode.getAttrMapRef()["\"dtype\""].getDataType() + curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType() assert(curNodeInputType is not Graph.DataTypeEnum.DT_INVALID) # NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code @@ -104,7 +104,7 @@ def Identity(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - curNodeDataType = curNode.getAttrMapRef()["\"T\""].getDataType() + curNodeDataType = curNode.getAttrMapRef()["T"].getDataType() assert(curNodeDataType is not Graph.DataTypeEnum.DT_INVALID) curNodeShape = extraNodeInfoDict[curNode.getName()][0] @@ -175,8 +175,8 @@ def FloorDiv(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr return (None, AST.Func(TFNodesAST.getOperatorsIdx('floor'), realDivAST)) def VariableV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - curNodeShapeLi = curNode.getAttrMapRef()["\"shape\""].getShape().getDimRef()[:] - curNodeInputType = curNode.getAttrMapRef()["\"dtype\""].getDataType() + curNodeShapeLi = curNode.getAttrMapRef()["shape"].getShape().getDimRef()[:] + curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType() # NOTE : since this becomes an input node right now, i have also added to be prefixed at top in ProcessTFGraph::prefixAllPlaceHolderNodes() # NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code @@ -188,8 +188,8 @@ def VariableV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarSt def Const(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): assert(len(curNode.getInputsRef()) == 0) - tensor = curNode.getAttrMapRef()["\"value\""].getTensor() - curNodeDataType = curNode.getAttrMapRef()["\"dtype\""].getDataType() + tensor = curNode.getAttrMapRef()["value"].getTensor() + curNodeDataType = curNode.getAttrMapRef()["dtype"].getDataType() curNodeShape = tensor.getShapeRef()[:] #create a different copy to not change the original copy tensorConstantVal = tensor.getConstantVal() @@ -235,8 +235,8 @@ def Shape(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : d def Cast(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 1) - sourceType = curNode.getAttrMapRef()["\"SrcT\""].getDataType() - destType = curNode.getAttrMapRef()["\"DstT\""].getDataType() + sourceType = curNode.getAttrMapRef()["SrcT"].getDataType() + destType = curNode.getAttrMapRef()["DstT"].getDataType() return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.Cast.name, [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), @@ -247,7 +247,7 @@ def Cast(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di def ZerosLike(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - curNodeOutputType = curNode.getAttrMapRef()["\"T\""].getDataType() + curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType() assert(curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID) retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.CreateTensor.name, @@ -261,7 +261,7 @@ def Fill(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di curNodeOutputShape = extraNodeInfoDict[inputsRef[0]][0] assert(len(curNodeOutputShape) == 1) #inputsRef[0] denotes a shape and should have a rank of 1 - curNodeOutputType = curNode.getAttrMapRef()["\"T\""].getDataType() + curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType() assert(curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID) retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], @@ -280,7 +280,7 @@ def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, img assert(FD) assert(strideD) zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 - if (paddingUsedStr == "\"SAME\""): + if (paddingUsedStr == "SAME"): # Reference for following: # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn totalPaddingH = totalPaddingW = totalPaddingD = 0 @@ -310,7 +310,7 @@ def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, img zPadDLeft = totalPaddingD // 2 zPadDRight = totalPaddingD - zPadDLeft - elif (paddingUsedStr == "\"VALID\""): + elif (paddingUsedStr == "VALID"): zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0 else: zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 @@ -324,7 +324,7 @@ def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : inputsRef = curNode.getInputsRef() assert(len(inputsRef)==2) - stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi() + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() assert(stridesUsed[0]==1 and stridesUsed[3]==1) strideH = stridesUsed[1] strideW = stridesUsed[2] @@ -337,7 +337,7 @@ def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : FH = filterShape[0] FW = filterShape[1] - paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS() + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, @@ -363,7 +363,7 @@ def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : inputsRef = curNode.getInputsRef() assert(len(inputsRef)==2) - stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi() + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() assert(stridesUsed[0]==1 and stridesUsed[4]==1) strideD = stridesUsed[1] strideH = stridesUsed[2] @@ -379,7 +379,7 @@ def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : FH = filterShape[1] FW = filterShape[2] - paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS() + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD ) @@ -406,7 +406,7 @@ def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNam inputsRef = curNode.getInputsRef() assert(len(inputsRef)==3) #output_shape, filter, input - stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi() + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() assert(stridesUsed[0]==1 and stridesUsed[4]==1) strideD = stridesUsed[1] strideH = stridesUsed[2] @@ -427,7 +427,7 @@ def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNam outputH = outputShape[2] outputW = outputShape[3] - paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS() + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() # Important: Using outputH and outputW in the below is not an error! # For convTranspose, the parameters passed in the node are of the conv of which this convTranspose is an inverse. @@ -463,12 +463,12 @@ def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameTo options = {} - stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi() + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() assert((stridesUsed[0] == 1) and (stridesUsed[3] == 1)) strideH = stridesUsed[1] strideW = stridesUsed[2] - kSizeUsed = curNode.getAttrMapRef()["\"ksize\""].getList().getILi() + kSizeUsed = curNode.getAttrMapRef()["ksize"].getList().getILi() assert((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)) kSizeH = kSizeUsed[1] kSizeW = kSizeUsed[2] @@ -477,7 +477,7 @@ def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameTo imgH = inputShape[1] imgW = inputShape[2] - paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS() + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, kSizeH, kSizeW, strideH, strideW, @@ -512,7 +512,7 @@ def AvgPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : def ConcatV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() - N = curNode.getAttrMapRef()["\"N\""].getI() + N = curNode.getAttrMapRef()["N"].getI() assert(len(inputsRef) == N+1) #One extra for axis #TODO : Since the axis of concat is constant, therefore, its known here - the input's sizes along that dim should be # passed as input to the below function. @@ -535,7 +535,7 @@ def ExpandDims(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarSt def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 3) - curNodeDataType = curNode.getAttrMapRef()["\"T\""].getDataType() + curNodeDataType = curNode.getAttrMapRef()["T"].getDataType() retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.CreateCopy.name, [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), # of this @@ -556,8 +556,8 @@ def Sum(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dic attrMapRef = curNode.getAttrMapRef() assert(len(inputsRef) == 2) keepdims = False - if ("\"keep_dims\"" in attrMapRef): - keepdims = attrMapRef["\"keep_dims\""].getB() + if ("keep_dims" in attrMapRef): + keepdims = attrMapRef["keep_dims"].getB() curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] return (None, AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), @@ -570,8 +570,8 @@ def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di attrMapRef = curNode.getAttrMapRef() assert(len(inputsRef) == 2) keepdims = False - if ("\"keep_dims\"" in attrMapRef): - keepdims = attrMapRef["\"keep_dims\""].getB() + if ("keep_dims" in attrMapRef): + keepdims = attrMapRef["keep_dims"].getB() curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] return (None, AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), @@ -601,12 +601,12 @@ def Square(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : def Pad(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): # Mode refers to 'CONSTANT', 'REFLECT' or 'SYMMETRIC' mode = 0 - if ("\"mode\"" in curNode.getAttrMapRef()): - mode = curNode.getAttrMapRef()["\"mode\""].getI() + if ("mode" in curNode.getAttrMapRef()): + mode = curNode.getAttrMapRef()["mode"].getI() constant_values = 0 - if ("\"constant_values\"" in curNode.getAttrMapRef()): - constant_values = curNode.getAttrMapRef()["\"constant_values\""].getI() + if ("constant_values" in curNode.getAttrMapRef()): + constant_values = curNode.getAttrMapRef()["constant_values"].getI() assert(mode == 0 and constant_values == 0) # For now to make life easy - deal with SYMMETRIC AND REFLECT when time comes inputsRef = curNode.getInputsRef() @@ -650,7 +650,7 @@ def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : inputTensorShape = extraNodeInfoDict[inputsRef[0]][0] inputTensorRank = len(inputTensorShape) - squeezeDims = curNode.getAttrMapRef()["\"squeeze_dims\""].getList().getILi() + squeezeDims = curNode.getAttrMapRef()["squeeze_dims"].getList().getILi() squeezeDimsRank = len(squeezeDims) return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], @@ -682,4 +682,4 @@ def Softmax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) def VarHandleOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - return TFNodesAST.VariableV2(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) + return TFNodesAST.VariableV2(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) From b54294e8c32e4342d72bb7eebe559903fab6eb6c Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 25 Nov 2020 16:46:58 +0530 Subject: [PATCH 03/72] Implement SquaredDifference We do this as a simplification on the tensorflow graph itself. We transform SquaredDifference(a,b) into (a-b) * (a-b). --- Athos/TFCompiler/Graph.py | 13 ++++++++----- Athos/TFCompiler/ProcessTFGraph.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index 078a3bb5..7acbd1d3 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -501,11 +501,14 @@ def getList(self): return self.__val class Node: - def __init__(self): - self.__name = "" #Name of node - self.__op = "" #Name of operation carried out by node - self.__inputs = [] #List of all inputs to the current node - self.__attr = {} #Map of (attrName, Value) of all attributes for the current node + def __init__(self, op="", inputs=None, name=""): + self.__name = name #Name of node + self.__op = op #Name of operation carried out by node + if inputs is None: + self.__inputs = [] #List of all inputs to the current node + else: + self.__inputs = inputs + self.__attr = {} #Map of (attrName, Value) of all attributes for the current node def getName(self): return self.__name diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index d2b409b6..b47e86bf 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -112,6 +112,32 @@ def prefixAllPlaceHolderNodes(graph): remNodes.append(curNode) graph.setNodesList(placeHolderNodes + remNodes) + +# List of Optimisations +# 1. Split squared difference into (a-b)*(a-b) +def simplifyGraph(graph): + allNodes = graph.getAllNodesRef() + nodesMap = graph.getAllNodes() + newNodes = [] + inputsFixup = {} + for curNode in allNodes: + inputs = curNode.getInputsRef() + for i in range(len(inputs)): + if inputs[i] in inputsFixup: + inputs[i] = inputsFixup[inputs[i]] + if (curNode.getOp() == "SquaredDifference"): + sub = Graph.Node("Sub", inputs.copy(), curNode.getName() + "__sub") + mul = Graph.Node("Mul", [sub.getName(), sub.getName()], curNode.getName() + "__mul") + newNodes.append(sub) + newNodes.append(mul) + nodesMap[sub.getName()] = sub + nodesMap[mul.getName()] = mul + inputsFixup[curNode.getName()] = mul.getName() + nodesMap.pop(curNode.getName()) + else: + newNodes.append(curNode) + graph.setNodesList(newNodes) + def main(): sys.setrecursionlimit(10000) @@ -131,6 +157,8 @@ def main(): sizeInfoFileName = os.path.join(folderName, 'sizeInfo.mtdata') sizeInfo = readSizeInfo(sizeInfoFileName) + # Tensorflow graph level optimisations + simplifyGraph(graph) # Place all PlaceHolder nodes together at the beginning prefixAllPlaceHolderNodes(graph) From b15c3227f5a1d589e2b685da073abdd3af0db2a0 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 25 Nov 2020 17:07:59 +0530 Subject: [PATCH 04/72] Fix double scaledown bug for mul like ops. Squarediff exposed a bug in codegen where both inputs to mul were same. Depending on the scale of the variable at that point, we sometimes do a scaledown of the inputs of multiplication so as to maintain precision. scaledown(a, scale) scaledown(b, scale) mul(a,b) But in this case both the inputs to mul were same so we were doing scaledown(a, scale) scaledown(a, scale) mul(a,a) This led to loss of precision. Now we just do: scaledown(a, scale) mul(a,a) --- Athos/SeeDot/IR/IRBuilderCSF.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 71db9440..9c7d69f6 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -493,12 +493,13 @@ def visitBopElemWiseOp(self, node:AST.BOp, args=None): if (Util.Config.disableTruncOpti): progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF) else: + inputs_same = (expr_1.idf == expr_2.idf) expr1_sf = self.scaleFacMapping[expr_1.idf] expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac) self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (expr2_sf > self.scaleFac): + if (not inputs_same) and (expr2_sf > self.scaleFac): progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac @@ -533,12 +534,13 @@ def visitBopMulInt(self, node:AST.BOp, args=None): if (Util.Config.disableTruncOpti): progExtraAfter = self.addTruncateFunctionCall(node, "MulInt", expr_3, Util.Config.consSF) else: + inputs_same = (expr_1.idf == expr_2.idf) expr1_sf = self.scaleFacMapping[expr_1.idf] expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MulInt", expr_1, expr1_sf-self.scaleFac) self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (expr2_sf > self.scaleFac): + if (not inputs_same) and (expr2_sf > self.scaleFac): progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MulInt", expr_2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac @@ -578,12 +580,13 @@ def visitBopMulScalar1DTensor(self, node:AST.BOp, args=None): if (Util.Config.disableTruncOpti): progExtraAfter = self.addTruncateFunctionCall(node, "ScalarMul", expr_3, Util.Config.consSF) else: + inputs_same = (expr_1.idf == expr_2.idf) expr1_sf = self.scaleFacMapping[expr_1.idf] expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ScalarMul", expr_1, expr1_sf-self.scaleFac) self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (expr2_sf > self.scaleFac): + if (not inputs_same) and (expr2_sf > self.scaleFac): progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ScalarMul", expr_2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac @@ -639,12 +642,13 @@ def visitBopMul2DTensor(self, node:AST.BOp, args=None): if (Util.Config.disableTruncOpti): progExtraAfter = self.addTruncateFunctionCall(node, "MatMul2D", expr_3, Util.Config.consSF) else: + inputs_same = (expr_1.idf == expr_2.idf) expr1_sf = self.scaleFacMapping[expr_1.idf] expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MatMul2D", expr_1, expr1_sf-self.scaleFac) self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (expr2_sf > self.scaleFac): + if (not inputs_same) and (expr2_sf > self.scaleFac): progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MatMul2D", expr_2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac @@ -728,12 +732,13 @@ def visitBopConv(self, node:AST.BOp, args=None): if (Util.Config.disableTruncOpti): progExtraAfter = self.addTruncateFunctionCall(node, "Conv", returnExpr, Util.Config.consSF) else: + inputs_same = (expr_1.idf == expr_2.idf) expr1_sf = self.scaleFacMapping[expr1.idf] expr2_sf = self.scaleFacMapping[expr2.idf] if (expr1_sf > self.scaleFac): progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr1, expr1_sf-self.scaleFac) self.scaleFacMapping[expr1.idf] = self.scaleFac - if (expr2_sf > self.scaleFac): + if (not inputs_same) and (expr2_sf > self.scaleFac): progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac @@ -833,12 +838,13 @@ def visitBopConvTranspose(self, node:AST.BOp, args=None): if (Util.Config.disableTruncOpti): progExtraAfter = self.addTruncateFunctionCall(node, "ConvTranspose", returnExpr, self.scaleFac) else: + inputs_same = (expr_1.idf == expr_2.idf) expr1_sf = self.scaleFacMapping[expr1.idf] expr2_sf = self.scaleFacMapping[expr2.idf] if (expr1_sf > self.scaleFac): progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ConvTranspose", expr1, expr1_sf-self.scaleFac) self.scaleFacMapping[expr1.idf] = self.scaleFac - if (expr2_sf > self.scaleFac): + if (not inputs_same) and (expr2_sf > self.scaleFac): progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ConvTranspose", expr2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr2.idf] = self.scaleFac self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac From 988fa41d49f32aaf84fefd0e652cbd1ac7ebedac Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 13:15:49 +0530 Subject: [PATCH 05/72] Automatically scale down final ouptut. Previously if there are mul like ops, the final output would be of scale = 2 * scaling_factor. Now we introduce a scale down (if required) so that the scale of the model output is = scaling_factor. --- Athos/SeeDot/AST/IRBuilderAST.py | 35 ++++++++++++++++++++++++++++++++ Athos/SeeDot/Compiler.py | 30 +++++++++++++++++++++++++++ Athos/SeeDot/IR/IRBuilderCSF.py | 3 ++- 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 Athos/SeeDot/AST/IRBuilderAST.py diff --git a/Athos/SeeDot/AST/IRBuilderAST.py b/Athos/SeeDot/AST/IRBuilderAST.py new file mode 100644 index 00000000..1d27e16b --- /dev/null +++ b/Athos/SeeDot/AST/IRBuilderAST.py @@ -0,0 +1,35 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2020 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import AST.AST as AST +from AST.ASTVisitor import ASTVisitor + +class IRBuilderAST(ASTVisitor): + typeInfo = {} + def visit(self, node, args=None): + ret = super().visit( node, args) + if type(ret) is tuple: + if ret[1].idf not in self.typeInfo: + self.typeInfo[ret[1].idf] = node.type + return ret \ No newline at end of file diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index 73690712..e32f342d 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -29,6 +29,7 @@ import IR.IR as IR import AST.AST as AST from Writer import Writer +import Type as Type from Type import InferType import IR.IRUtil as IRUtil from AST.PrintAST import PrintAST @@ -77,6 +78,34 @@ def insertStartEndFunctionCalls(self, res:(IR.Prog, IR.Expr)): prog.cmd_l.append(IR.FuncCall('EndComputation', [])) return (prog, expr) + def fixOuputScale(self, res:(IR.Prog, IR.Expr), compiler:IRBuilderCSF): + prog = res[0] + expr = res[1] + output_scale = compiler.scaleFacMapping[expr.idf] + if output_scale == Util.Config.consSF: + return (prog, expr) + elif output_scale > Util.Config.consSF: + scale_down = output_scale - Util.Config.consSF + type = compiler.typeInfo[expr.idf] + if Type.isInt(type): + output_shape = [] + if Type.isTensor(type): + output_shape = type.shape + + argsDict = OrderedDict() + funcName = "ScaleDown" + for ii, curDimSize in enumerate(output_shape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + funcName = funcName + str(len(output_shape)) + argsDict[expr] = "expr" + argsDict[IR.Int(scale_down,32)] = "consSF" + funcCall = IR.FuncCall(funcName, argsDict) + new_prog = IR.Prog([funcCall]) + prog = IRUtil.prog_merge(prog, new_prog) + return (prog, expr) + else: + assert False, "Scale up shouldnt be required of final output. We lost precision somewhere" + def run(self): with open(Util.Config.astFile, 'rb') as ff: ast = pickle.load(ff) @@ -104,6 +133,7 @@ def run(self): IRUtil.init() compiler = IRBuilderCSF() res = compiler.visit(ast) + res = self.fixOuputScale(res, compiler); Util.write_debug_info(compiler.name_mapping) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 9c7d69f6..2370995e 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -33,8 +33,9 @@ import AST.AST as AST import IR.IRUtil as IRUtil from AST.ASTVisitor import ASTVisitor +from AST.IRBuilderAST import IRBuilderAST -class IRBuilderCSF(ASTVisitor): +class IRBuilderCSF(IRBuilderAST): varNameDelim = '' def __init__(self, intPartBitwidth=-1): # For tracking temp variables From 5ff67a26e9d50c488b38dd4de319875f044d9d3f Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 13:21:10 +0530 Subject: [PATCH 06/72] Implement StopGradient as no-op --- Athos/TFCompiler/TFNodesAST.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 74cc02ac..9b4fc5c0 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -678,6 +678,10 @@ def ReadVariableOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutV return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) def Softmax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) + + def StopGradient(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) From 3a1af19103cf8be0b9d4c83cd4776971a01b1b44 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 13:50:29 +0530 Subject: [PATCH 07/72] Add frontend support for Tanh, Sigmoid and [R]Sqrt --- Athos/SeeDot/AST/AST.py | 8 +++++ Athos/SeeDot/IR/IRBuilderCSF.py | 55 +++++++++++++++++++++++++++------ Athos/TFCompiler/TFNodesAST.py | 24 +++++++++++--- 3 files changed, 74 insertions(+), 13 deletions(-) diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index c3521e82..199e8c0f 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -31,6 +31,10 @@ "CONV": '#', "CONVTRANSPOSE": "#T", #ConvTranspose "RELU": 'relu', + "TANH": 'tanh', + "SIGMOID": 'sigmoid', + "SQRT": 'sqrt', + "RSQRT": 'rsqrt', "Equal": '==', "ElemWiseMul":'.*', "ElemWiseDiv": './', @@ -48,6 +52,10 @@ class Operators(Enum): CONV = auto() CONVTRANSPOSE = auto() RELU = auto() + TANH = auto() + SIGMOID = auto() + SQRT = auto() + RSQRT = auto() Equal = auto() ElemWiseMul = auto() ElemWiseDiv = auto() diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 2370995e..14a05944 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -53,6 +53,8 @@ def __init__(self, intPartBitwidth=-1): # Name mapping from SeeDot names to new names is useful for debugging self.name_mapping = {} + self.actualbitwidth = Util.Config.actualWordLength + #This is for optimizing the #truncation calls self.scaleFac = Util.Config.consSF self.bitwidth = Util.Config.wordLength @@ -856,12 +858,14 @@ def visitBopConvTranspose(self, node:AST.BOp, args=None): def visitFunc(self, node:AST.Func, args=None): op = node.op - assert(op in [AST.Operators.Floor, AST.Operators.Shape, AST.Operators.RELU, AST.Operators.ClearMemSecret, AST.Operators.ClearMemPublic]) + assert(op in [AST.Operators.Floor, AST.Operators.Shape, AST.Operators.RELU, AST.Operators.TANH, + AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT, + AST.Operators.ClearMemSecret, AST.Operators.ClearMemPublic]) return self.visitFloorLike(node) def visitFloorLike(self, node:AST.Func, args=None): (prog1, expr1) = self.visit(node.expr) - tmpExpr = self.getTempVar() + out_expr = self.getTempVar() if node.op == AST.Operators.Floor: funcName = "Floor" @@ -869,6 +873,14 @@ def visitFloorLike(self, node:AST.Func, args=None): funcName = "Shape" elif node.op == AST.Operators.RELU: funcName = "Relu" + elif node.op == AST.Operators.TANH: + funcName = "Tanh" + elif node.op == AST.Operators.SIGMOID: + funcName = "Sigmoid" + elif node.op == AST.Operators.SQRT: + funcName = "Sqrt" + elif node.op == AST.Operators.RSQRT: + funcName = "Sqrt" elif node.op == AST.Operators.ClearMemSecret: funcName = "ClearMemSecret" elif node.op == AST.Operators.ClearMemPublic: @@ -885,39 +897,64 @@ def visitFloorLike(self, node:AST.Func, args=None): argsList[expr1] = "inArr" if Type.isTensor(node.type): - argsList[tmpExpr] = "outArr" + argsList[out_expr] = "outArr" if node.op == AST.Operators.Floor: argsList[IR.Int(Util.Config.consSF,32)] = "curScale" - progExtra = IR.Prog([]) + progExtraBefore = IR.Prog([]) if (Util.Config.disableTruncOpti): if node.op == AST.Operators.RELU: argsList[IR.Int(Util.Config.consSF,32)] = "consSF" argsList[IR.Bool(False)] = "doTruncation" + if node.op in [AST.Operators.TANH, AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT]: + argsList[IR.Int(self.scaleFac,32)] = "sA" + argsList[IR.Int(self.scaleFac,32)] = "sB" else: final_sf = self.scaleFacMapping[expr1.idf] if node.op == AST.Operators.RELU: argsList[IR.Int(final_sf - self.scaleFac,32)] = "consSF" if (final_sf > self.scaleFac): #If it can't tolerate one more mult operation, then scale down here + assert(final_sf - self.scaleFac == self.scaleFac) final_sf = self.scaleFac argsList[IR.Bool(True)] = "doTruncation" else: argsList[IR.Bool(False)] = "doTruncation" - self.scaleFacMapping[tmpExpr.idf] = final_sf + if node.op in [AST.Operators.TANH, AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT]: + # Since these class of fucntions can only handle input of 32 bitlength, we have to scale down + # inputs before calling them. + if final_sf > 32: + assert (final_sf > self.scaleFac), "The program scaling factor is invalid. Should be lesser than 32 if network has tan/sig/sqrt" + assert(final_sf - self.scaleFac == self.scaleFac) + progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr, node.op.name, expr1, final_sf - self.scaleFac)) + self.scaleFacMapping[expr1.idf] = self.scaleFac + final_sf = self.scaleFac + argsList[IR.Int(final_sf,32)] = "sA" + argsList[IR.Int(final_sf,32)] = "sB" + self.scaleFacMapping[out_expr.idf] = final_sf + + # Tanh/Sigmoid/Sqrt impl only supports upto 32 bitwidth for input + if node.op in [AST.Operators.TANH, AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT]: + argsList[IR.Int(min(32, self.actualbitwidth), 32)] = "bwA" + argsList[IR.Int(self.actualbitwidth, 32)] = "bwB" + if node.op == AST.Operators.SQRT: + argsList[IR.Bool(False)] = "inverse" + if node.op == AST.Operators.RSQRT: + argsList[IR.Bool(True)] = "inverse" + argsList[IR.Int(8,32)] = "LUTBITS" comment = IR.Comment(str(node.metadata)) funcNameSuffix = "" if Type.isTensor(inputType): funcNameSuffix = str(len(inputType.shape)) - progFinal = IRUtil.prog_merge(prog1 , IR.Prog([comment, IR.FuncCall(funcName + self.varNameDelim + funcNameSuffix, argsList)])) + progFinal = IR.Prog([comment, IR.FuncCall(funcName + self.varNameDelim + funcNameSuffix, argsList)]) if Type.isTensor(node.type): - progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), progFinal) + progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(out_expr.idf, node.type)]), progFinal) - progFinal = IRUtil.prog_merge(progFinal, progExtra) - return (progFinal, tmpExpr) + progFinal = IRUtil.prog_merge(prog1, progExtraBefore, progFinal) + return (progFinal, out_expr) def visitLet(self, node:AST.Let, args=None): (prog_1, expr_1) = self.visit(node.decl) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 9b4fc5c0..6624c90f 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -227,6 +227,26 @@ def Relu(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di assert(len(inputsRef)==1) return (None, AST.Func(TFNodesAST.getOperatorsIdx('relu'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + def Tanh(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==1) + return (None, AST.Func(TFNodesAST.getOperatorsIdx('tanh'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + + def Sqrt(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==1) + return (None, AST.Func(TFNodesAST.getOperatorsIdx('sqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + + def Rsqrt(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==1) + return (None, AST.Func(TFNodesAST.getOperatorsIdx('rsqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + + def Sigmoid(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==1) + return (None, AST.Func(TFNodesAST.getOperatorsIdx('sigmoid'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + def Shape(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) @@ -669,10 +689,6 @@ def BiasAdd(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) )) - def Sigmoid(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) - def ReadVariableOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) From 4c04295d63ec4044c7bced143aa8808d5b89703b Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 13:56:43 +0530 Subject: [PATCH 08/72] Don't add clearmem calls for freeing scalars --- Athos/SeeDot/IR/IRBuilderCSF.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 14a05944..89a1123e 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -888,8 +888,15 @@ def visitFloorLike(self, node:AST.Func, args=None): else: assert False + # We don't need to clear scalars. + if node.op == AST.Operators.ClearMemSecret or node.op == AST.Operators.ClearMemPublic: + if Type.isInt(node.expr.type): + return (prog1, expr1) + if node.expr.type.dim == 0: + return (prog1, expr1) + argsList = OrderedDict() - + inputType = node.expr.type if Type.isTensor(inputType): for ii, curDim in enumerate(inputType.shape): From b6b2658df441f0c9e98889328d6e9a47d02e2b44 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 14:05:04 +0530 Subject: [PATCH 09/72] Fix codegen for unary negation of tensors --- Athos/SeeDot/IR/IRBuilderCSF.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 89a1123e..fa0c444c 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -365,7 +365,7 @@ def visitUOp(self, node:AST.UOp, args=None): # cmdl_assn expr_1_elt = IRUtil.addIndex(expr_1, iters) expr_2_elt = IRUtil.addIndex(expr_2, iters) - cmdl_assn = IRUtil.loop(typ_2.shape, iters, [IR.Assn(expr_2_elt, IRUtil.negate(expr_1_elt))]) + cmdl_assn = IRUtil.loop(typ_2.shape, iters, [IR.Assn(expr_2_elt, IRUtil.sub(IRUtil.zero, expr_1_elt))]) comment = IR.Comment(str(node.metadata)) prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment] + cmdl_assn)) prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2) From 3844d8cf30d7e28ac97a7e093aaf1ed39a084a2e Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 15:13:11 +0530 Subject: [PATCH 10/72] Taint Analysis added to type inference For every tensor in the graph, we want to know it's 'taint'. Each tensor can have the possible taints: Client: Input to the ML model (eg: the image input). Server: The weights of the model. ClientXServer: A tensor that is dervied after operations on both client and server tensors. Secret_constant: A tensor that is a constant but declared as a secret. Public_constant: A tensor that is a constant but declared as public. The motivation behind this analysis is to insert optimized versions of multiplication. If one input is from server and the other from model we can call ElemWiseActModelVectorMult (optimized) otherwise we can insert a call to ElemWiseSecretSharedVectorMult Matmul also expects one of its inputs to have the 'Server' taint which this analysis identifies it. --- Athos/SeeDot/AST/AST.py | 10 +- Athos/SeeDot/AST/PrintAST.py | 6 +- Athos/SeeDot/Codegen/EzPC.py | 8 +- Athos/SeeDot/Compiler.py | 6 +- Athos/SeeDot/IR/IR.py | 5 +- Athos/SeeDot/IR/IRBuilderCSF.py | 31 +++--- Athos/SeeDot/Type.py | 184 +++++++++++++++++++++++++------- 7 files changed, 181 insertions(+), 69 deletions(-) diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index 199e8c0f..bfd6d06c 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -45,6 +45,10 @@ "ClearMemPublic": 'clearmempublic' } +class Party(Enum): + SERVER = 0 + CLIENT = 1 + class Operators(Enum): ADD = auto() SUB = auto() @@ -361,12 +365,12 @@ def __init__(self, expr:ID, dim:ID, keepdims:Int, outShape:list, op: Operators): # Also, take note of the last parameter - "inputByParty". This can be used to set the party which # which will do the input for this variable. Defaults to 0, which is interpretted as SERVER by the codegen. class Input(ASTNode): - def __init__(self, shape:list, dataType:str, isSecret=True, inputByParty=0): + def __init__(self, shape:list, dataType:str, isSecret=True, inputByParty=Party.SERVER): if assertInputTypes: for elem in shape: assert isinstance(elem, int) assert isinstance(dataType, str) - assert isinstance(inputByParty, int) - assert(inputByParty==0 or inputByParty==1) #Right now EzPC supports input by two parties. + assert isinstance(inputByParty, Party) + assert(inputByParty==Party.CLIENT or inputByParty==Party.SERVER) #Right now EzPC supports input by two parties. super().__init__() self.shape = shape self.dataType = dataType diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py index e89f1315..ac1c9f0c 100644 --- a/Athos/SeeDot/AST/PrintAST.py +++ b/Athos/SeeDot/AST/PrintAST.py @@ -89,10 +89,12 @@ def visitLet(self, node:AST.Let, args=None): node.expr.depth = node.depth + 1 print(indent * node.depth, "(", end=' ') print("let", end=' ') + if(hasattr(node.name, 'type') and hasattr(node.name.type, 'taint')): + print("<", node.decl.type.taint.name, ">",end=' ') self.visit(node.name) print("=", end=' ') self.visit(node.decl) - print("in", "{", node.metadata[AST.ASTNode.mtdKeyTFOpName], node.metadata[AST.ASTNode.mtdKeyTFNodeName], "}", end='\n') + print("{", node.metadata[AST.ASTNode.mtdKeyTFOpName], node.metadata[AST.ASTNode.mtdKeyTFNodeName], "} in ", end='\n') self.visit(node.expr) print(')',end='') @@ -113,7 +115,7 @@ def visitReduce(self, node:AST.Reduce, args=None): self.visit(node.keepdims) def visitInput(self, node:AST.Input, args=None): - print(indent * node.depth, "input( ", node.shape, node.dataType, end='') + print(indent * node.depth, "input( ", node.shape, node.dataType, " <", node.inputByParty.name, "> ", end='') print(" )", end='') def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): diff --git a/Athos/SeeDot/Codegen/EzPC.py b/Athos/SeeDot/Codegen/EzPC.py index 76627228..a14681a2 100644 --- a/Athos/SeeDot/Codegen/EzPC.py +++ b/Athos/SeeDot/Codegen/EzPC.py @@ -100,12 +100,8 @@ def printInt(self, ir:IR.Int): assert False def printInput(self, ir:IR.Input): - if (ir.inputByParty==0): - inputByPartyStr = "SERVER" - elif (ir.inputByParty==1): - inputByPartyStr = "CLIENT" - else: - assert(False) #For now the only supported values of party to input is 0 or 1 + inputByPartyStr = ir.inputByParty.name + assert(inputByPartyStr == "SERVER" or inputByPartyStr == "CLIENT") #For now the only supported values of party to input is 0 or 1 self.out.printf('input({0}, {1}, '.format(inputByPartyStr, ir.expr.idf), indent=True) #assert(ir.dataType in ["DT_INT32"]) ####TODO: fix this if Util.Config.wordLength == 32: diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index e32f342d..2642f096 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -123,13 +123,13 @@ def run(self): LivenessOpti.LivenessOpti().visit(ast, [mtdAST, 0, {}]) print("Liveness optimization done.") + # Perform type inference and annotate nodes with type information + InferType().visit(ast) + if Util.Config.printASTBool: PrintAST().visit(ast) sys.stdout.flush() - # Perform type inference - InferType().visit(ast) - IRUtil.init() compiler = IRBuilderCSF() res = compiler.visit(ast) diff --git a/Athos/SeeDot/IR/IR.py b/Athos/SeeDot/IR/IR.py index 25e9d135..b08db5f3 100644 --- a/Athos/SeeDot/IR/IR.py +++ b/Athos/SeeDot/IR/IR.py @@ -26,6 +26,7 @@ import numpy as np import Util, Type +import AST.AST as AST #TODO - check if this can be cleaned up class Op(): @@ -278,7 +279,7 @@ def subst(self, from_idf:str, to_e:Expr): return self.__class__(self.name, argList_new) class Input(Cmd): - def __init__(self, expr:Expr, shape:list, dataType:str, isSecret=True, inputByParty=0): + def __init__(self, expr:Expr, shape:list, dataType:str, isSecret=True, inputByParty=AST.Party.SERVER): self.expr = expr self.shape = shape self.dataType = dataType @@ -286,7 +287,7 @@ def __init__(self, expr:Expr, shape:list, dataType:str, isSecret=True, inputByPa self.inputByParty = inputByParty def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.expr.subst(from_idf, to_e), self.shape, self.dataType, self.isSecret) + return self.__class__(self.expr.subst(from_idf, to_e), self.shape, self.dataType, self.isSecret, self.inputByParty) class Decl(Cmd): def __init__(self, varIdf:str, typeExpr:Type.Type, bitlen:int=-1, isSecret:bool=True, value:list=None): diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index fa0c444c..e3aabe97 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -63,8 +63,6 @@ def __init__(self, intPartBitwidth=-1): self.intPartBitwidth = self.bitwidth - 2*self.scaleFac self.scaleFacMapping = {} - self.inputByPartyMapping = {} - def getConsSF(self): return Util.Config.consSF @@ -115,6 +113,12 @@ def addTruncateFunctionCall(self, node:AST.ASTNode, nodeTypeStr: str, expr:IR.Va prog = IR.Prog([comment, funcCall]) return prog + def isModel(self, node:AST.ASTNode): + if(node.type.taint == Type.Taints.SERVER): + return True + else: + return False + #================= # Visit Functions #================= @@ -625,19 +629,16 @@ def visitBopMul2DTensor(self, node:AST.BOp, args=None): funcCallArgsDict[expr_2] = "B" funcCallArgsDict[expr_3] = "C" - #Add an arg as to which arg out of A or B is the model - # This is ok, since Athos is right now tailored for neural network inference - # and in inference, in every linear layer, either of A or B will be the model. - # This is required because for some backends, knowing which of A or B is the model - # can make a difference in their performance. + # Add an arg as to which arg out of A or B is a model weight + # This is ok, since Athos is right now tailored for neural network inference + # and in inference, in every linear layer, either of A or B will be a model weight. + # This is required because for some backends, knowing which of A or B is a model weight + # can make a difference in their performance. modelIsA = True - if ((expr_1.idf in self.inputByPartyMapping) and (self.inputByPartyMapping[expr_1.idf]==0)): - modelIsA = True - elif ((expr_2.idf in self.inputByPartyMapping) and (self.inputByPartyMapping[expr_2.idf]==0)): + assert (self.isModel(node.expr1) or self.isModel(node.expr2)), "Expecting one of A or B to be an input by the server (model weight)." + modelIsA = True + if (not self.isModel(node.expr1)): modelIsA = False - else: - print("Expecting one of A or B to be the model, input by the server.") - assert(False) funcCallArgsDict[IR.Bool(modelIsA)] = "modelIsA" progExtraBefore = IR.Prog([]) @@ -971,9 +972,6 @@ def visitLet(self, node:AST.Let, args=None): self.name_mapping[idf] = expr_1.idf if (not(Util.Config.disableTruncOpti)): self.scaleFacMapping[idf] = self.scaleFacMapping[expr_1.idf] - if (expr_1.idf in self.inputByPartyMapping): - self.inputByPartyMapping[idf] = self.inputByPartyMapping[expr_1.idf] - del self.inputByPartyMapping[expr_1.idf] (prog_2, expr_2) = self.visit(node.expr) prog_2 = prog_2.subst(idf, expr_1) expr_2 = expr_2.subst(idf, expr_1) @@ -1083,7 +1081,6 @@ def visitInput(self, node:AST.Input, args=None): comment = IR.Comment(str(node.metadata)) if not(Util.Config.disableTruncOpti): self.scaleFacMapping[returnExpr.idf] = self.scaleFac - self.inputByPartyMapping[returnExpr.idf] = node.inputByParty return (IR.Prog([comment, IR.Input(returnExpr, node.shape, node.dataType, node.isSecret, node.inputByParty)]), returnExpr) def visitReduce(self, node:AST.Reduce, args=None): diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 402a32a1..914ba9de 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -27,23 +27,103 @@ from functools import reduce import AST.AST as AST from AST.ASTVisitor import ASTVisitor +from enum import Enum, auto +import copy class Type: pass +''' +We want to analyse the taint of every tensor that flows in the graph. +The possible taints for tensors are: +{ + Client: Input to the ML model (eg: the image input) + Server: The weights of the model + ClientXServer[C&S]: A tensor that is dervied after operations on both client and server tensors. + Secret_constant: A tensor that is a constant but declared as a secret + Public_constant: A tensor that is a constant but declared as public +} +Note: For ML models we don't expect to encounter any secret_constants and instead expect them +to be encoded as weights of the model and so instead has the server taint. + +We infer taints in the following manner: + Client Server C&S Secret_constant Public_constant +Client Client C&S C&S Client Client +Server C&S Server C&S Server Server +C&S C&S C&S C&S C&S C&S +Secret_constant C&S C&S C&S Secret_constant Secret_constant +Public_constant Client Server C&S Secret_constant Public_constant +''' + +class Taints(Enum): + CLIENT = auto() + SERVER = auto() + CLIENT_SERVER = auto() + SECRET_C = auto() + PUBLIC_C = auto() + +constantTaintsMapping = { True : Taints.SECRET_C, False : Taints.PUBLIC_C} + +TaintsTable = { + Taints.CLIENT : { + Taints.CLIENT : Taints.CLIENT, + Taints.SERVER : Taints.CLIENT_SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.CLIENT, + Taints.PUBLIC_C: Taints.CLIENT + }, + Taints.SERVER : { + Taints.CLIENT : Taints.CLIENT_SERVER, + Taints.SERVER : Taints.SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.SERVER, + Taints.PUBLIC_C: Taints.SERVER + }, + Taints.CLIENT_SERVER : { + Taints.CLIENT : Taints.CLIENT_SERVER, + Taints.SERVER : Taints.CLIENT_SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.CLIENT_SERVER, + Taints.PUBLIC_C: Taints.CLIENT_SERVER + }, + Taints.SECRET_C : { + Taints.CLIENT : Taints.CLIENT, + Taints.SERVER : Taints.SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.SECRET_C, + Taints.PUBLIC_C: Taints.SECRET_C + }, + Taints.PUBLIC_C : { + Taints.CLIENT : Taints.CLIENT, + Taints.SERVER : Taints.SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.SECRET_C, + Taints.PUBLIC_C: Taints.PUBLIC_C + } + } +def getTaint_taint(t1: Taints, t2: Taints): + return TaintsTable[t1][t2] + +def getTaint_type(t1: Type, t2: Type): + return TaintsTable[t1.taint][t2.taint] + class Int(Type): - def __init__(self, bitlen=-1, isSecret=False): + def __init__(self, bitlen=-1, isSecret=False, taint=Taints.PUBLIC_C): if bitlen==-1: self.bitlen = Util.Config.wordLength else: self.bitlen = bitlen self.isSecret = isSecret + self.taint = taint + + def __copy__(self): + return type(self)(self.bitlen, self.isSecret, self.taint) class Unit(Type): pass class Tensor(Type): - def __init__(self, shape:list, bitlen=-1, isSecret=True): + def __init__(self, shape:list, bitlen=-1, isSecret=True, taint=Taints.PUBLIC_C): self.shape = shape self.dim = len(shape) if bitlen==-1: @@ -51,6 +131,10 @@ def __init__(self, shape:list, bitlen=-1, isSecret=True): else: self.bitlen = bitlen self.isSecret = isSecret + self.taint = taint + + def __copy__(self): + return type(self)(self.shape, self.bitlen, self.isSecret, self.taint) def size(self): return reduce(operator.mul, self.shape, 1) @@ -83,12 +167,12 @@ def visitInt(self, node:AST.Int, args=None): bitlen = Util.Config.wordLength if node.bitLen: bitlen = node.bitLen - node.type = Int(bitlen, node.isSecret) + node.type = Int(bitlen, node.isSecret, constantTaintsMapping[node.isSecret]) return node.type def visitFloat(self, node:AST.Float, args=None): - # Float is represented as a tensor with 0 dimension - node.type = Tensor([], -1, node.isSecret) + # Float is represented as an int in fixedpt. + node.type = Int(isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret]) return node.type def visitId(self, node:AST.ID, args=None): @@ -101,7 +185,10 @@ def visitId(self, node:AST.ID, args=None): def visitDecl(self, node:AST.Decl, args=None): #TODO -- fill in bitlen properly - node.type = Tensor(node.shape, -1, node.isSecret) + if (node.shape == []): + node.type = Int(isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret]) + else: + node.type = Tensor(shape=node.shape, isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret]) return node.type def visitTranspose(self, node:AST.Transpose, args=None): @@ -117,7 +204,7 @@ def visitTranspose(self, node:AST.Transpose, args=None): new_shape = [] for i in perm: new_shape.append(shape[i]) - node.type = Tensor(new_shape) + node.type = Tensor(new_shape, exprType.bitlen, exprType.isSecret, exprType.taint) return node.type def visitReshape(self, node:AST.Reshape, args=None): @@ -128,7 +215,7 @@ def visitReshape(self, node:AST.Reshape, args=None): # Reshape is valid if the total number of elements remain same after reshape assert reduce(operator.mul, exprType.shape, 1) == reduce(operator.mul, node.shape, 1) - node.type = Tensor(node.shape) + node.type = Tensor(node.shape, exprType.bitlen, exprType.isSecret, exprType.taint) return node.type @@ -151,7 +238,7 @@ def visitPool(self, node:AST.Pool, args=None): newH = ((H + zPadHLeft + zPadHRight - FH)//strideH) + 1 newW = ((W + zPadWLeft + zPadWRight - FW)//strideW) + 1 - node.type = Tensor([N, newH, newW, CI]) + node.type = Tensor([N, newH, newW, CI], exprType.bitlen, exprType.isSecret, exprType.taint) return node.type @@ -202,7 +289,7 @@ def typeCheckBroadcastOps(self, node:AST.BOp, eType:Type, fType:Type): assert len(eType.shape) >= len(fType.shape) if isInt(eType) and isInt(fType): - node.type = Int(eType.bitlen, eType.isSecret) + node.type = Int(eType.bitlen) elif isTensor(eType) and isTensor(fType): revETypeShape = eType.shape[::-1] revFTypeShape = fType.shape[::-1] @@ -214,10 +301,13 @@ def typeCheckBroadcastOps(self, node:AST.BOp, eType:Type, fType:Type): assert False # Broadcast possible - node.type = eType + node.type = copy.copy(eType) else: print(eType, fType) assert False + + node.type.taint = getTaint_type(eType, fType) + node.type.isSecret = eType.isSecret | fType.isSecret return node.type def visitBopMul(self, node:AST.BOp, eType:Type, fType:Type, args=None): @@ -225,19 +315,22 @@ def visitBopMul(self, node:AST.BOp, eType:Type, fType:Type, args=None): node.type = Int(eType.bitlen, eType.isSecret) elif isTensor(eType) and isTensor(fType): if eType.dim == 0: - node.type = fType + node.type = copy.copy(fType) elif fType.dim == 0: - node.type = eType + node.type = copy.copy(eType) else: assert eType.dim == 2 and fType.dim == 2 [n1, n2] = eType.shape [n3, n4] = fType.shape assert n2 == n3 - node.type = Tensor([n1, n4]) + node.type = Tensor([n1, n4], eType.bitlen) else: print("Error: Unknown condition in type checking.", file=sys.stderr) assert(False) + node.type.taint = getTaint_type(eType, fType) + node.type.isSecret = eType.isSecret | fType.isSecret + return node.type def visitBopConv(self, node:AST.BOp, eType:Type, fType:Type, args=None): @@ -291,7 +384,7 @@ def visitBopConv(self, node:AST.BOp, eType:Type, fType:Type, args=None): shape = [N, newH, newW, CO] elif convDim == 3: shape = [N, newD, newH, newW, CO] - node.type = Tensor(shape) + node.type = Tensor(shape, eType.bitlen, eType.isSecret | fType.isSecret, getTaint_type(eType, fType)) return node.type def visitBopConvTranspose(self, node:AST.BOp, eType:Type, fType:Type, args=None): @@ -328,7 +421,7 @@ def visitBopConvTranspose(self, node:AST.BOp, eType:Type, fType:Type, args=None) # of size shape = [N, outputImgH, outputImgW, CI], and filter of size [FH, FW, CI, CO]. # Hence, the input for this convTranspose would be [N, HP, WP, CO] - node.type = Tensor(shape) + node.type = Tensor(shape, eType.bitlen, eType.isSecret | fType.isSecret, getTaint_type(eType, fType)) return node.type def visitBopAddLike(self, node:AST.BOp, eType: Type, fType: Type, args=None): @@ -339,7 +432,9 @@ def visitBopAddLike(self, node:AST.BOp, eType: Type, fType: Type, args=None): else: assert False - node.type = eType + node.type = copy.copy(eType) + node.type.taint = getTaint_type(eType, fType) + node.type.isSecret = eType.isSecret | fType.isSecret return node.type def visitFunc(self, node:AST.Func, args=None): @@ -348,18 +443,30 @@ def visitFunc(self, node:AST.Func, args=None): if node.op == AST.Operators.RELU: assert isTensor(eType) and eType.dim >= 1 - node.type = eType + node.type = copy.copy(eType) + elif node.op == AST.Operators.TANH: + assert isTensor(eType) + node.type = copy.copy(eType) + elif node.op == AST.Operators.SIGMOID: + assert isTensor(eType) + node.type = copy.copy(eType) + elif node.op == AST.Operators.SQRT: + assert isTensor(eType) + node.type = copy.copy(eType) + elif node.op == AST.Operators.RSQRT: + assert isTensor(eType) + node.type = copy.copy(eType) elif node.op == AST.Operators.Floor: - node.type = eType + node.type = copy.copy(eType) elif node.op == AST.Operators.Shape: assert isTensor(eType) - node.type = Tensor([len(eType.shape)]) + node.type = Tensor([len(eType.shape)], eType.bitlen, eType.isSecret, eType.taint) elif node.op == AST.Operators.ClearMemSecret: node.type = Unit() elif node.op == AST.Operators.ClearMemPublic: node.type = Unit() else: - print(node.op) + print("Type inference not implemented for", node.op) assert False return node.type @@ -368,49 +475,51 @@ def visitLet(self, node:AST.Let, args=None): node.decl.gamma = dict(node.gamma) eType = self.visit(node.decl) + node.name.gamma = { node.name.name : eType} + self.visit(node.name) + node.expr.gamma = dict(node.gamma) node.expr.gamma[node.name.name] = eType fType = self.visit(node.expr) - node.type = fType + node.type = copy.copy(fType) return node.type def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args=None): # Assert that outputShape and inputDims are lists of int astNode. assert(len(node.argsList) > 0) + isSecret = False + taint = Taints.PUBLIC_C for curArg in node.argsList: curArg.gamma = dict(node.gamma) - self.visit(curArg) #This should set the type of each of the input nodes + eType = self.visit(curArg) #This should set the type of each of the input nodes + isSecret = isSecret | eType.isSecret + taint = getTaint_taint(taint, eType.taint) outputShape = node.outputShape - node.type = Tensor(outputShape) + node.type = Tensor(outputShape, isSecret=isSecret, taint=taint) return node.type def visitArgMax(self, node:AST.ArgMax, args=None): node.expr.gamma = dict(node.gamma) - inpTensorType = self.visit(node.expr) + eType = self.visit(node.expr) node.dim.gamma = dict(node.gamma) dimType = self.visit(node.dim) assert(isInt(dimType) or (isTensor(dimType) and (len(dimType.shape)==0))) - node.type = Tensor(node.outputShape) + node.type = Tensor(node.outputShape, eType.bitlen, eType.isSecret, eType.taint) return node.type def visitReduce(self, node:AST.Reduce, args=None): cur_gamma = dict(node.gamma) node.expr.gamma = cur_gamma - node.dim.gamma = cur_gamma - node.keepdims.gamma = cur_gamma - - self.visit(node.expr) - self.visit(node.dim) - self.visit(node.keepdims) + eType = self.visit(node.expr) - node.type = Tensor(node.outShape) + node.type = Tensor(node.outShape, eType.bitlen, eType.isSecret, eType.taint) return node.type def visitInput(self, node:AST.Input, args=None): - node.type = Tensor(node.shape) + node.type = Tensor(node.shape, isSecret=node.isSecret, taint=Taints[node.inputByParty.name]) return node.type def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): @@ -431,6 +540,9 @@ def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): assert(exprType.shape[-1]==C1 and C1==C2) - node.type = exprType - return node.type + taint = getTaint_taint(exprType.taint, multExprType.taint) + taint = getTaint_taint(taint, addExprType.taint) + node.type = copy.copy(exprType) + node.type.taint = taint + return node.type \ No newline at end of file From f416af1ec7004d12aa1a7fdc55fa74c4ad5a1046 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 16:41:28 +0530 Subject: [PATCH 11/72] Support for Split operation We support splitting of a tensor along an axis into n pieces, where n has to be a constant. Eg: Split(Tensor of shape(5,30), splits=3, axis=1) returns 3 tensors of shape(5,10) each. Currently we do not suport splitting into tensors of specified shape (num_or_size_splits) though that functionality will be added later. We also do not support splitting into n pieces where n is a runtime value because we do not support run-time code generation yet. This also adds support in the frontend for an op to return multiple values. --- Athos/SeeDot/AST/AST.py | 15 +- Athos/SeeDot/AST/ASTVisitor.py | 5 + Athos/SeeDot/AST/MtdAST.py | 4 + Athos/SeeDot/AST/PrintAST.py | 6 + Athos/SeeDot/IR/IRBuilderCSF.py | 31 ++++ Athos/SeeDot/Optimizations/LivenessOpti.py | 5 + Athos/SeeDot/Type.py | 20 +++ Athos/TFCompiler/ProcessTFGraph.py | 50 +++--- Athos/TFCompiler/TFNodesAST.py | 168 ++++++++++++--------- 9 files changed, 208 insertions(+), 96 deletions(-) diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index bfd6d06c..1ecc85d7 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -196,6 +196,19 @@ def __init__(self, expr: ASTNode, perm: list = None): self.expr = expr self.perm = perm +# expr : ASTNode, perm : list of ints +class Slice(ASTNode): + def __init__(self, expr: ASTNode, subscriptRanges: list = None): + if assertInputTypes: + assert isinstance(expr, ID) + if subscriptRanges: + for elem in subscriptRanges: + assert isinstance(elem[0], int) + assert isinstance(elem[1], int) + super().__init__() + self.expr = expr + self.subscriptRanges = subscriptRanges + # expr : ASTNode, shape : list of int, order : int : optional class Reshape(ASTNode): def __init__(self, expr: ASTNode, shape: list, order: list): @@ -363,7 +376,7 @@ def __init__(self, expr:ID, dim:ID, keepdims:Int, outShape:list, op: Operators): # NOTE: Though datatype is being passed to this function, the output code eventually only has # int in the apt bitlen for which the whole compilation is done # Also, take note of the last parameter - "inputByParty". This can be used to set the party which -# which will do the input for this variable. Defaults to 0, which is interpretted as SERVER by the codegen. +# which will do the input for this variable. Defaults to SERVER. class Input(ASTNode): def __init__(self, shape:list, dataType:str, isSecret=True, inputByParty=Party.SERVER): if assertInputTypes: diff --git a/Athos/SeeDot/AST/ASTVisitor.py b/Athos/SeeDot/AST/ASTVisitor.py index fbed995d..1f4d6632 100644 --- a/Athos/SeeDot/AST/ASTVisitor.py +++ b/Athos/SeeDot/AST/ASTVisitor.py @@ -42,6 +42,9 @@ def visitDecl(self, node:AST.Decl, args=None): def visitTranspose(self, node:AST.Transpose, args=None): self.visit(node.expr, args) + def visitSlice(self, node:AST.Slice, args=None): + self.visit(node.expr, args) + def visitReshape(self, node:AST.Reshape, args=None): self.visit(node.expr, args) @@ -97,6 +100,8 @@ def visit(self, node, args=None): return self.visitDecl(node, args) elif isinstance(node, AST.Transpose): return self.visitTranspose(node, args) + elif isinstance(node, AST.Slice): + return self.visitSlice(node, args) elif isinstance(node, AST.Reshape): return self.visitReshape(node, args) elif isinstance(node, AST.Pool): diff --git a/Athos/SeeDot/AST/MtdAST.py b/Athos/SeeDot/AST/MtdAST.py index 3d209556..e9d4614e 100644 --- a/Athos/SeeDot/AST/MtdAST.py +++ b/Athos/SeeDot/AST/MtdAST.py @@ -42,6 +42,10 @@ def visitTranspose(self, node:AST.Transpose, mtd:dict): node.metadata.update(mtd) self.visit(node.expr, mtd) + def visitSlice(self, node:AST.Slice, mtd:dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + def visitReshape(self, node:AST.Reshape, mtd:dict): node.metadata.update(mtd) self.visit(node.expr, mtd) diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py index ac1c9f0c..84025491 100644 --- a/Athos/SeeDot/AST/PrintAST.py +++ b/Athos/SeeDot/AST/PrintAST.py @@ -51,6 +51,12 @@ def visitTranspose(self, node:AST.Transpose, args=None): self.visit(node.expr) print("^Transpose", end=' ') + def visitSlice(self, node:AST.Transpose, args=None): + node.expr.depth = node.depth + 1 + print(indent * node.depth, end=' ') + self.visit(node.expr) + print("extract slice", end=' ') + def visitReshape(self, node:AST.Reshape, args=None): node.expr.depth = node.depth + 1 print(indent * node.depth, "reshape", end=' ') diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index e3aabe97..c315b291 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -231,6 +231,37 @@ def visitTranspose(self, node:AST.Transpose, args=None): return (final_prog, out_arr) + def visitSlice(self, node:AST.Slice, args=None): + (inp_prog, inp_arr) = self.visit(node.expr) + inp_type = node.expr.type + out_type = node.type + out_iters = self.getTempIterators(out_type.dim) + inp_iters = [] + subscriptRanges = node.subscriptRanges + for idx,subrange in enumerate(subscriptRanges): + start = subrange[0] + inp_iters.append(IRUtil.add(out_iters[idx], IR.Int(start))) + + out_arr = self.getTempVar() + out_arr_expr = IRUtil.addIndex(out_arr, out_iters) + inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters) + assign_expr = IR.Assn(out_arr_expr, inp_arr_expr) + loop = IRUtil.loop(out_type.shape, out_iters, [assign_expr]) + # Finalize + comment1 = IR.Comment(str(node.metadata)) + comment2 = IR.Comment("slice(" + inp_arr.idf + ", [" + ', '.join(str(e) for e in inp_type.shape) + "] --> [" + ', '.join(str(e) for e in out_type.shape) + "])") + slice_prog = IR.Prog([comment1, comment2] + loop) + final_prog = IRUtil.prog_merge(inp_prog, slice_prog) + + for var in out_iters: + final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), final_prog) + final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog) + + if not(Util.Config.disableTruncOpti): + self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[inp_arr.idf] + + return (final_prog, out_arr) + def visitReshape(self, node:AST.Reshape, args=None): (prog_1, expr_1) = self.visit(node.expr) diff --git a/Athos/SeeDot/Optimizations/LivenessOpti.py b/Athos/SeeDot/Optimizations/LivenessOpti.py index 7a8ebc78..161881f5 100644 --- a/Athos/SeeDot/Optimizations/LivenessOpti.py +++ b/Athos/SeeDot/Optimizations/LivenessOpti.py @@ -52,6 +52,11 @@ def visitTranspose(self, node:AST.Transpose, args): node.optidict[self.optidictKey] = unboundVars return unboundVars + def visitSlice(self, node:AST.Slice, args): + unboundVars = self.visit(node.expr, args) + node.optidict[self.optidictKey] = unboundVars + return unboundVars + def visitReshape(self, node:AST.Reshape, args): unboundVars = self.visit(node.expr, args) node.optidict[self.optidictKey] = unboundVars diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 914ba9de..0420e751 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -207,6 +207,26 @@ def visitTranspose(self, node:AST.Transpose, args=None): node.type = Tensor(new_shape, exprType.bitlen, exprType.isSecret, exprType.taint) return node.type + def visitSlice(self, node:AST.Slice, args=None): + node.expr.gamma = dict(node.gamma) + exprType = self.visit(node.expr) + assert isTensor(exprType) + + subscriptRanges = node.subscriptRanges + shape = [] + for i in subscriptRanges: + start = i[0] + end = i[1] + size = end - start + 1 + shape.append(size) + + assert(len(shape) == len(exprType.shape)) + for i in range(0,len(shape)): + assert(shape[i] <= exprType.shape[i]) + + node.type = Tensor(shape, exprType.bitlen, exprType.isSecret, exprType.taint) + return node.type + def visitReshape(self, node:AST.Reshape, args=None): node.expr.gamma = dict(node.gamma) exprType = self.visit(node.expr) diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index b47e86bf..915ef1e5 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -37,8 +37,8 @@ def generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDic curNodeOp = curNode.getOp() ast = None func = getattr(TFNodesAST, curNodeOp) - (assignedVarAST, curAST) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) - return (assignedVarAST, curAST) + (assignedVarAST, curASTs) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) + return (assignedVarAST, curASTs) #Takes the graph DS and outputs IR in SeeDot for the same def generateIRCode(graph, extraInfoDict): @@ -51,29 +51,29 @@ def generateIRCode(graph, extraInfoDict): for curNode in graph.getAllNodesRef(): for curInp in curNode.getInputsRef(): assert(curInp in dictNodeNameToOutVarStr) #Consequence of topological sorting of the TF graph - (assignedVarAST, curAst) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict) - - mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(), - AST.ASTNode.mtdKeyTFNodeName : curNode.getName()} - - if (curAst is None): - dictNodeNameToOutVarStr[curNode.getName()] = None - continue - curOutVarStr = outVarPrefix + str(outVarCt) - curOutVarAstNode = (assignedVarAST if assignedVarAST else AST.ID(curOutVarStr)) - if program: - assert(type(innerMostLetASTNode) is AST.Let) - newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode) - mtdAST.visit(newNode, mtdForCurAST) - innerMostLetASTNode.expr = newNode - innerMostLetASTNode = newNode - else: - innerMostLetASTNode = AST.Let(AST.ID(curOutVarStr), curAst, curOutVarAstNode) - mtdAST.visit(innerMostLetASTNode, mtdForCurAST) - innerMostLetASTNode.depth = 0 - program = innerMostLetASTNode - dictNodeNameToOutVarStr[curNode.getName()] = curOutVarStr - outVarCt += 1 + (assignedVarAST, curAsts) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict) + for outputName, curAst in curAsts.items(): + mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(), + AST.ASTNode.mtdKeyTFNodeName : outputName} + + if (curAst is None): + dictNodeNameToOutVarStr[outputName] = None + continue + curOutVarStr = outVarPrefix + str(outVarCt) + curOutVarAstNode = (assignedVarAST if assignedVarAST else AST.ID(curOutVarStr)) + if program: + assert(type(innerMostLetASTNode) is AST.Let) + newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode) + mtdAST.visit(newNode, mtdForCurAST) + innerMostLetASTNode.expr = newNode + innerMostLetASTNode = newNode + else: + innerMostLetASTNode = AST.Let(AST.ID(curOutVarStr), curAst, curOutVarAstNode) + mtdAST.visit(innerMostLetASTNode, mtdForCurAST) + innerMostLetASTNode.depth = 0 + program = innerMostLetASTNode + dictNodeNameToOutVarStr[outputName] = curOutVarStr + outVarCt += 1 return (program, dictNodeNameToOutVarStr) def readSizeInfo(fileName): diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 6624c90f..530f4775 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -74,7 +74,7 @@ def MatMul(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : transposeBBool = attrMapRef["transpose_b"].getB() if (transposeABool): inp1AST = AST.Transp(inp1AST) if (transposeBBool): inp2AST = AST.Transp(inp2AST) - return (None, AST.BOp(inp1AST, TFNodesAST.getOperatorsIdx('*'), inp2AST)) + return (None, { curNode.getName() : AST.BOp(inp1AST, TFNodesAST.getOperatorsIdx('*'), inp2AST)}) def Placeholder(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] @@ -87,15 +87,15 @@ def Placeholder(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarS # all model parameters are represented using Variable op nodes. # Hence, in the call to AST.Input, we pass inputByParty=1. - return (None, AST.Input(curNodeShapeLi, curNodeInputType.name, isSecret=True, inputByParty=1)) + return (None, { curNode.getName() : AST.Input(curNodeShapeLi, curNodeInputType.name, isSecret=True, inputByParty=AST.Party.CLIENT)}) def Equal(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('=='), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )) + )}) def Identity(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): #In SeeDot, J2=J1 creates a new reference for J1 -- so @@ -111,59 +111,59 @@ def Identity(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr retAST = AST.UninterpFuncCall(curNodeShape, TFNodesAST.UninterpFuncCallNames.CreateIdentity.name, [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])]) - return (None, retAST) + return (None, { curNode.getName() : retAST}) def Add(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('+'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )) + )}) def AddV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('+'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )) + )}) def Mul(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('.*'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )) + )}) def Neg(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 1) - return (None, AST.UOp(TFNodesAST.getOperatorsIdx('-'), + return (None, { curNode.getName() : AST.UOp(TFNodesAST.getOperatorsIdx('-'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]) - )) + )}) def Sub(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('+'), AST.UOp(TFNodesAST.getOperatorsIdx('-'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - ))) + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) + ))}) def Floor(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 1) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('floor'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('floor'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) def RealDiv(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('./'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )) + )}) def FloorDiv(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -172,7 +172,7 @@ def FloorDiv(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr TFNodesAST.getOperatorsIdx('./'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) ) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('floor'), realDivAST)) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('floor'), realDivAST)}) def VariableV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): curNodeShapeLi = curNode.getAttrMapRef()["shape"].getShape().getDimRef()[:] @@ -183,8 +183,8 @@ def VariableV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarSt # (in the scenario of secure inference), model is input by server and image by client. # We assume in the following that the PlaceHolder op node represents the image and # all model parameters are represented using Variable op nodes. - # Hence, in the call to AST.Input, we pass inputByParty=0. - return (None, AST.Input(curNodeShapeLi, curNodeInputType.name, isSecret=True, inputByParty=0)) + # Hence, in the call to AST.Input, we pass inputByParty as SERVER. + return (None, { curNode.getName() : AST.Input(curNodeShapeLi, curNodeInputType.name, isSecret=True, inputByParty=AST.Party.SERVER)}) def Const(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): assert(len(curNode.getInputsRef()) == 0) @@ -220,49 +220,49 @@ def Const(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : d else: assert False retAST = AST.Decl(curNodeShape, None, dataPassed, isSecret=False) - return (None, retAST) + return (None, { curNode.getName() : retAST}) def Relu(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('relu'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('relu'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) def Tanh(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('tanh'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('tanh'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) def Sqrt(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('sqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('sqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) def Rsqrt(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('rsqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('rsqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) def Sigmoid(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('sigmoid'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('sigmoid'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) def Shape(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==1) - return (None, AST.Func(TFNodesAST.getOperatorsIdx('shape'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))) + return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('shape'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) def Cast(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 1) sourceType = curNode.getAttrMapRef()["SrcT"].getDataType() destType = curNode.getAttrMapRef()["DstT"].getDataType() - return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], + return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.Cast.name, [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(sourceType.name), AST.ID(destType.name) - ])) + ])}) def ZerosLike(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -273,7 +273,7 @@ def ZerosLike(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr TFNodesAST.UninterpFuncCallNames.CreateTensor.name, [AST.Int(0, isSecret=False)], isSecret=False) - return (None, retAST) + return (None, { curNode.getName() : retAST}) def Fill(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -288,12 +288,12 @@ def Fill(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di TFNodesAST.UninterpFuncCallNames.CreateTensor.name, [AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) ], isSecret=False) - return (None, retAST) + return (None, { curNode.getName() : retAST}) def Reshape(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None)) + return (None, { curNode.getName() : AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None)}) def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD = None, FD = None, strideD = None): if imgD: @@ -374,10 +374,10 @@ def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : options[AST.PaddingKeysDict.zPadWRight] = zPadWRight options[AST.PaddingKeysDict.strideH] = strideH options[AST.PaddingKeysDict.strideW] = strideW - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('#'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - options)) + options)}) def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -417,10 +417,10 @@ def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : options[AST.PaddingKeysDict.strideH] = strideH options[AST.PaddingKeysDict.strideW] = strideW options[AST.PaddingKeysDict.ConvDim] = 3 - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('#'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - options)) + options)}) def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -472,10 +472,10 @@ def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNam options[AST.PaddingKeysDict.outputImgD] = outputD options[AST.PaddingKeysDict.outputImgH] = outputH options[AST.PaddingKeysDict.outputImgW] = outputW - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), TFNodesAST.getOperatorsIdx('#T'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - options)) + options)}) def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict, typeOfPool:str): inputsRef = curNode.getInputsRef() @@ -510,7 +510,7 @@ def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameTo else: print("Unknown type of pooling layer.", file=sys.stderr) assert(False) - return (None, AST.Pool(poolType, + return (None, { curNode.getName() : AST.Pool(poolType, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), { AST.PaddingKeysDict.FH: kSizeH, @@ -522,7 +522,7 @@ def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameTo AST.PaddingKeysDict.strideH: strideH, AST.PaddingKeysDict.strideW: strideW } - )) + )}) def MaxPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): return TFNodesAST.helper_processPool(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict, 'MAXPOOL') @@ -542,7 +542,7 @@ def ConcatV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)), outputDiffInpDims=1 ) - return (None, retAST) + return (None, { curNode.getName() : retAST}) def ExpandDims(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -550,7 +550,7 @@ def ExpandDims(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarSt retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.ExpandDims.name, list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef))) - return (None, retAST) + return (None, { curNode.getName() : retAST}) def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -562,14 +562,14 @@ def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : d AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), # begin idx AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]) # size ]) - return (None, retAST) + return (None, { curNode.getName() : retAST}) def Tile(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], + return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.Tile.name, - [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])])) + [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])])}) def Sum(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -579,11 +579,11 @@ def Sum(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dic if ("keep_dims" in attrMapRef): keepdims = attrMapRef["keep_dims"].getB() curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] - return (None, AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), AST.Int(int(keepdims), 32, isSecret=False), curNodeShapeLi, - TFNodesAST.getOperatorsIdx('+'))) + TFNodesAST.getOperatorsIdx('+'))}) def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -593,30 +593,30 @@ def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di if ("keep_dims" in attrMapRef): keepdims = attrMapRef["keep_dims"].getB() curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] - return (None, AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), AST.Int(int(keepdims), 32, isSecret=False), curNodeShapeLi, - TFNodesAST.getOperatorsIdx('mean'))) + TFNodesAST.getOperatorsIdx('mean'))}) def ArgMax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.ArgMax(extraNodeInfoDict[curNode.getName()][0], + return (None, { curNode.getName() : AST.ArgMax(extraNodeInfoDict[curNode.getName()][0], AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - extraNodeInfoDict[inputsRef[0]][0])) + extraNodeInfoDict[inputsRef[0]][0])}) def NoOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - return (None, None) + return (None, { curNode.getName() : None}) def Square(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 1) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('.*'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]) - )) + )}) def Pad(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): # Mode refers to 'CONSTANT', 'REFLECT' or 'SYMMETRIC' @@ -631,27 +631,27 @@ def Pad(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dic assert(mode == 0 and constant_values == 0) # For now to make life easy - deal with SYMMETRIC AND REFLECT when time comes inputsRef = curNode.getInputsRef() inputTensorShapeLi = extraNodeInfoDict[inputsRef[0]][0] - return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], + return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.Pad.name, [ AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) ], outputDiffInpDims=1 - )) + )}) def FusedBatchNorm(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() - return (None, AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), - )) + )}) def FusedBatchNormV3(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() - return (None, AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), - )) + )}) def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -663,7 +663,35 @@ def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr permList = permTensor.getContentAsValArr() assert(permTensor.getDType().kind == "i") assert(permTensor.getShapeRef().getRank() == 1) - return (None, AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList)) + return (None, { curNode.getName() : AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList)}) + + def Split(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef) == 2) + axisNodeName = inputsRef[0] # split_dim input. Has to be a constant. We don't support dynamic codegen yet + axisNode = graph.__getitem__(axisNodeName) + axisTensor = axisNode.getAttrVal("value").getTensor() + axis = axisTensor.getConstantVal() + numSplits = curNode.getAttrVal("num_split").getI() + inputTensorShape = extraNodeInfoDict[inputsRef[1]][0] + assert(axis < len(inputTensorShape)) + assert(inputTensorShape[axis] % numSplits == 0) #Should perfectly split + sizeAlongSplitDim = int(inputTensorShape[axis]/numSplits) + outputAsts = {} + for i in range(0, numSplits): + output_name = curNode.getName() + if i != 0: + output_name += ":" + str(i) + subscriptRanges = [] + for j in range(0, len(inputTensorShape)): + start = 0 + end = inputTensorShape[j] - 1 + if j == axis: + start = i*sizeAlongSplitDim + end = start + sizeAlongSplitDim - 1 + subscriptRanges.append((start,end)) + outputAsts[output_name] = AST.Slice(AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), subscriptRanges) + return (None, outputAsts) def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -673,33 +701,33 @@ def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : squeezeDims = curNode.getAttrMapRef()["squeeze_dims"].getList().getILi() squeezeDimsRank = len(squeezeDims) - return (None, AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], + return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], TFNodesAST.UninterpFuncCallNames.Squeeze.name, list(map(lambda x : AST.Int(x, 32, isSecret=False), squeezeDims)) + [ AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]) ] - )) + )}) def BiasAdd(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 2) - return (None, AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), TFNodesAST.getOperatorsIdx('+'), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )) + )}) def ReadVariableOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() - return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) + return (None, { curNode.getName() : AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}) def Softmax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() - return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) + return (None, { curNode.getName() : AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}) def StopGradient(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() - return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])) + return (None, { curNode.getName() : AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}) def VarHandleOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): return TFNodesAST.VariableV2(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) From cee44f6d66150057721eafb7d78148f721b840b1 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 17:43:49 +0530 Subject: [PATCH 12/72] Support for reduced mean Adds support the reduce_mean operation in tensorflow. Consider the example: For inputs: Tensor of shape(s0,s1,s2,s3) reduction axes = [0,3] We generate the following program: If keep_dim == true output is of shape(1,s1,s2,1) else output is of shape(s1,s2) for i1=[0:s1] for i2=[0:s2] sum = 0 for i0=[0:s0] for i3=[0:s3] sum = sum + input[i0][i1][i2][i3] output[i1][i2] = sum / (s0 * s3) // keep_dim=false OR output[0][i1][i2][0] = sum / (s0 * s3) // keep_dim=true TODO: Also add support for reduced sum. --- Athos/SeeDot/AST/AST.py | 7 +- Athos/SeeDot/AST/ASTVisitor.py | 2 - Athos/SeeDot/AST/MtdAST.py | 1 - Athos/SeeDot/AST/PrintAST.py | 2 - Athos/SeeDot/IR/IRBuilderCSF.py | 150 +++++++++++++++++---- Athos/SeeDot/Optimizations/LivenessOpti.py | 2 +- Athos/TFCompiler/TFNodesAST.py | 26 +++- 7 files changed, 149 insertions(+), 41 deletions(-) diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index 1ecc85d7..ca747045 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -356,21 +356,20 @@ def __init__(self, outputShape: list, expr: ID, dim: ASTNode, inShape: list): self.inShape = inShape class Reduce(ASTNode): - def __init__(self, expr:ID, dim:ID, keepdims:Int, outShape:list, op: Operators): + def __init__(self, expr:ID, keepdims:bool, outShape:list, op: Operators, reductionAxesList: list): # keepdims is unused for now if assertInputTypes: assert isinstance(expr, ID) - assert isinstance(dim, ID) - assert isinstance(keepdims, Int) + assert isinstance(keepdims, bool) assert isinstance(outShape, list) for elem in outShape: assert isinstance(elem, int) assert isinstance(op, Operators) super().__init__() self.expr = expr - self.dim = dim self.keepdims = keepdims self.outShape = outShape self.op = op + self.reductionAxesList = reductionAxesList # shape : list of int, dataType : ID # NOTE: Though datatype is being passed to this function, the output code eventually only has diff --git a/Athos/SeeDot/AST/ASTVisitor.py b/Athos/SeeDot/AST/ASTVisitor.py index 1f4d6632..03f04ad3 100644 --- a/Athos/SeeDot/AST/ASTVisitor.py +++ b/Athos/SeeDot/AST/ASTVisitor.py @@ -76,8 +76,6 @@ def visitArgMax(self, node:AST.ArgMax, args=None): def visitReduce(self, node:AST.Reduce, args=None): self.visit(node.expr, args) - self.visit(node.dim, args) - self.visit(node.keepdims, args) def visitInput(self, node:AST.Input, args=None): pass diff --git a/Athos/SeeDot/AST/MtdAST.py b/Athos/SeeDot/AST/MtdAST.py index e9d4614e..ef9a4102 100644 --- a/Athos/SeeDot/AST/MtdAST.py +++ b/Athos/SeeDot/AST/MtdAST.py @@ -85,7 +85,6 @@ def visitArgMax(self, node:AST.ArgMax, mtd:dict): def visitReduce(self, node:AST.Reduce, mtd:dict): node.metadata.update(mtd) self.visit(node.expr, mtd) - self.visit(node.dim, mtd) def visitInput(self, node:AST.Input, mtd:dict): node.metadata.update(mtd) diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py index 84025491..1ef915da 100644 --- a/Athos/SeeDot/AST/PrintAST.py +++ b/Athos/SeeDot/AST/PrintAST.py @@ -117,8 +117,6 @@ def visitArgMax(self, node:AST.ArgMax, args=None): def visitReduce(self, node:AST.Reduce, args=None): print(indent * node.depth, "reduce", AST.OperatorsSymbolDict[node.op.name], end=' ') self.visit(node.expr) - self.visit(node.dim) - self.visit(node.keepdims) def visitInput(self, node:AST.Input, args=None): print(indent * node.depth, "input( ", node.shape, node.dataType, " <", node.inputByParty.name, "> ", end='') diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index c315b291..ea9b8083 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -1116,37 +1116,139 @@ def visitInput(self, node:AST.Input, args=None): def visitReduce(self, node:AST.Reduce, args=None): (prog_1, expr1) = self.visit(node.expr) - (prog_2, expr2) = self.visit(node.dim) - - returnExpr = self.getTempVar() - assert(node.op in [AST.Operators.ADD, AST.Operators.Mean]) - if (node.op == AST.Operators.ADD): - funcName = "ReduceSum" - elif (node.op == AST.Operators.Mean): - funcName = "ReduceMean" - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[returnExpr.idf] = self.scaleFacMapping[expr1.idf] + # We already have the output shape so we dont need to calculate with keep_dims - funcArgsList = OrderedDict() - outputShape = node.type.shape - for ii, curDim in enumerate(outputShape): - funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii) + ''' + We need to reduce across axes. + Example: Say reduction axes are specified as 0,3 and keep dim = false + output rank -> len(input_shape) - len(reduction_axes) + output is 2D. + for i1=[0:s1] + for i2=[0:s2] + sum = 0 + for i0=[0:s0] + for i3=[0:s3] + sum = sum + input[i0][i1][i2][i3] + output[i1][i2] = sum / (s0 * s3) + if keep dim == true, output rank is same as input. We generate: + output[0][i1][i2][0] = sum / (s0 * s3) + + Ideally the above loop nest is what we would want to generate. But since we have + a division, we need to make calls to the div functionality and flatten the tensors. + temp_flat[s1*s2]; + out_flat[s1*s2]; + for i1=[0:s1] + for i2=[0:s2] + sum = 0 + for i0=[0:s0] + for i3=[0:s3] + sum = sum + input[i0][i1][i2][i3] + temp_flat[i1*s2 + i2] = sum + ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat) + for i1=[0:s1] + for i2=[0:s2] + output[i1][i2] = out_flat[i1*s2 + i2] + ''' + reduced_dims = node.reductionAxesList inputShape = node.expr.type.shape - for ii, curDim in enumerate(inputShape): - funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii) + perm = [] + calculated_shape = [] + inputiters = self.getTempIterators(node.expr.type.dim) + outputiters = [] + no_elems = 1 + j = 0 + for i in range(len(inputShape)): + if i not in reduced_dims: + perm.append(i) + calculated_shape.append(inputShape[i]) + outputiters.append(inputiters[j]) + j = j + 1 + else: + no_elems = no_elems * inputShape[i] + if node.keepdims == 1: + calculated_shape.append(1) + outputiters.append(IR.Int(0,32)) + # perm will now be [ 1 ,2 ] + [ 0, 3] + perm.extend(reduced_dims) + loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + outputShape = node.type.shape + assert(calculated_shape == outputShape) + + sumExpr = self.getTempVar() + sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int()) + initSumCmd = IR.Assn(sumExpr, IRUtil.zero) + updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, inputiters))) + + outer_nesting = len(inputShape) - len(reduced_dims) + temp_flat = self.getTempVar() + temp_flat_decl = IR.Decl(temp_flat.idf, + Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), + isSecret=node.type.isSecret) + # i1*s2 + i2 + flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) + # temp_flat[i1*s2 + i2] = sum + temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) + updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + + # Generate the sum loop + inner_loops_processed = 0 + sum_loop = [updateSumCmd] + for i in reversed(range(len(loop_shape))): + sum_loop = [IR.For(inputiters[i], 0, sum_loop, 0, endInt=loop_shape[i])] + inner_loops_processed+=1 + if(inner_loops_processed == len(reduced_dims)): + sum_loop = [initSumCmd] + sum_loop + [updateOutCmd] + + # Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat) + out_flat = self.getTempVar() + out_flat_decl = IR.Decl(out_flat.idf, + Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), + isSecret=node.type.isSecret) + argsDict = OrderedDict() + argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" + argsDict[temp_flat] = "input" + argsDict[IR.Int(Util.get_volume(loop_shape[outer_nesting:]), 32)] = "divisor" + argsDict[out_flat] = "output" + div_call = IR.FuncCall("ElemWiseVectorPublicDiv", argsDict) + + # Free temp_flat here + # Clear temp arrays + argsDict = OrderedDict() + argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" + argsDict[temp_flat] = "A" + free_temp_flat_call = IR.FuncCall("ClearMemSecret1", argsDict) + + # Unflatten the output + output = self.getTempVar() + output_decl = IR.Decl(output.idf, node.type) + out_expr = IRUtil.addIndex(output, outputiters) + out_flat_expr = IRUtil.addIndex(out_flat, [flat_idx_expr]) + out_assn_expr = IR.Assn(out_expr, out_flat_expr) + unflatten_loop = IRUtil.loop(loop_shape[:outer_nesting], inputiters[:outer_nesting], [out_assn_expr]) + + # Free out_flat here + argsDict = OrderedDict() + argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" + argsDict[out_flat] = "A" + free_out_flat_call = IR.FuncCall("ClearMemSecret1", argsDict) + + if not(Util.Config.disableTruncOpti): + self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] - funcArgsList[expr1] = "inputArr" - funcArgsList[expr2] = "dimension" - funcArgsList[returnExpr] = "outArr" - funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)) + self.varNameDelim + str(len(inputShape)), funcArgsList) comment = IR.Comment(str(node.metadata)) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall])) + final_prog = IRUtil.prog_merge( prog_1, + IR.Prog([comment]), + IR.Prog([sumExpr_decl, temp_flat_decl, out_flat_decl, output_decl]), + IR.Prog(sum_loop), + IR.Prog([div_call]), + IR.Prog([free_temp_flat_call]), + IR.Prog(unflatten_loop), + IR.Prog([free_out_flat_call])) - prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), prog_3) - return (prog_3, returnExpr) + return (final_prog, output) def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): (prog1, expr1) = self.visit(node.expr) @@ -1175,7 +1277,7 @@ def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): addExpr_sf = self.scaleFacMapping[expr3.idf] if (expr_sf > self.scaleFac): #Scale down needed - progExtraBefore = self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac) + progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac)) self.scaleFacMapping[expr1.idf] = self.scaleFac if (multExpr_sf > self.scaleFac): diff --git a/Athos/SeeDot/Optimizations/LivenessOpti.py b/Athos/SeeDot/Optimizations/LivenessOpti.py index 161881f5..d69131f1 100644 --- a/Athos/SeeDot/Optimizations/LivenessOpti.py +++ b/Athos/SeeDot/Optimizations/LivenessOpti.py @@ -107,7 +107,7 @@ def visitArgMax(self, node:AST.ArgMax, args): return unboundVars def visitReduce(self, node:AST.Reduce, args): - unboundVars = list(set(self.visit(node.expr, args) + self.visit(node.dim, args) + self.visit(node.keepdims, args))) + unboundVars = list(set(self.visit(node.expr, args))) node.optidict[self.optidictKey] = unboundVars return unboundVars diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 530f4775..1a8f1460 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -578,12 +578,18 @@ def Sum(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dic keepdims = False if ("keep_dims" in attrMapRef): keepdims = attrMapRef["keep_dims"].getB() + + reductionAxesNodeName = inputsRef[1] + redAxesN = graph.__getitem__(reductionAxesNodeName) + redAxesT = redAxesN.getAttrVal("value").getTensor() + reductionAxesList = redAxesT.getContentAsValArr() + curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - AST.Int(int(keepdims), 32, isSecret=False), + keepdims, curNodeShapeLi, - TFNodesAST.getOperatorsIdx('+'))}) + TFNodesAST.getOperatorsIdx('+'), + reductionAxesList)}) def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() @@ -592,12 +598,18 @@ def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di keepdims = False if ("keep_dims" in attrMapRef): keepdims = attrMapRef["keep_dims"].getB() + + reductionAxesNodeName = inputsRef[1] + redAxesN = graph.__getitem__(reductionAxesNodeName) + redAxesT = redAxesN.getAttrVal("value").getTensor() + reductionAxesList = redAxesT.getContentAsValArr() + curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] - return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - AST.Int(int(keepdims), 32, isSecret=False), + return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + keepdims, curNodeShapeLi, - TFNodesAST.getOperatorsIdx('mean'))}) + TFNodesAST.getOperatorsIdx('mean'), + reductionAxesList)}) def ArgMax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() From c449271d373b693fded6924eab27453350fd793c Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 18:10:06 +0530 Subject: [PATCH 13/72] Add support for broadcasting semantics for binops. We add broadcasting support for add, sub, mul and equal. The broadcasting semantics are specified here https://numpy.org/doc/stable/user/basics.broadcasting.html Say we are given input A (4d array): 8 x 1 x 6 x 1 B (3d array): 7 x 1 x 5 We generate a loop with Result (4d array): 8 x 7 x 6 x 5 for i0=[0:8] for i1=[0:7] for i2=[0:6] for i3=[0:8] Result[i0][i1][i2][i3] = A[i0][0][i2][0] {+,*,-,==} B[i1][0][i3] --- Athos/SeeDot/IR/IRBuilderCSF.py | 236 ++++++++++++++++++++++++-------- Athos/SeeDot/IR/IRUtil.py | 55 ++++++++ Athos/SeeDot/Type.py | 52 ++----- Athos/SeeDot/Util.py | 61 +++++++++ 4 files changed, 303 insertions(+), 101 deletions(-) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index ea9b8083..6f889ae1 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -425,19 +425,16 @@ def visitBopAddOrSubLike(self, node:AST.BOp, args=None): op = node.op if (op == AST.Operators.ADD): - (op_ir, op_fn) = (IR.Op.Op['+'], operator.add) - funcName = "MatAdd" + op_ir = IR.Op.Op['+'] elif (op == AST.Operators.SUB): - (op_ir, op_fn) = (IR.Op.Op['-'], operator.sub) - funcName = "MatSub" + op_ir = IR.Op.Op['-'] elif (op == AST.Operators.Equal): - (op_ir, op_fn) = (IR.Op.Op['=='], operator.eq) - funcName = "MatEqual" + op_ir = IR.Op.Op['=='] else: assert False - typ_3 = node.type - expr_3 = self.getTempVar() + node_type = node.type + out_arr = self.getTempVar() cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf) comment = IR.Comment(str(node.metadata)) @@ -468,43 +465,54 @@ def visitBopAddOrSubLike(self, node:AST.BOp, args=None): argsDict[exprToScale] = "exprToScale, arg#{0}".format(2 if (expr1_sf>expr2_sf) else 1) argsDict[IR.Int(scaleUpFactor, 32)] = "ScaleUpFactor" funcCall = IR.FuncCall(curFuncName, argsDict) - curProg = IR.Prog([comm,funcCall]) + + if Type.isInt(typeOfExprToScale) or typeOfExprToScale.shape == []: + assn_expr = IR.Assn(exprToScale, funcCall) + curProg = IR.Prog([comm,assn_expr]) + else: + curProg = IR.Prog([comm,funcCall]) prog_1 = IRUtil.prog_merge(curProg, prog_1) - self.scaleFacMapping[expr_3.idf] = self.scaleFacMapping[expr_1.idf] + self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[expr_1.idf] - if Type.isInt(typ_3): - decl = IR.Decl(expr_3.idf, typ_3, typ_3.bitlen, typ_3.isSecret) - assign = IR.Assn(expr_3, IR.IntBop(expr_1, op_ir, expr_2)) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, decl, assign])) - else: - ## TODO - if (node.type.dim != node.expr1.type.dim): - # This needs broadcast of expr1 - assert False # For now this shouldn't occur - if (node.type.dim != node.expr2.type.dim): - # This needs broadcast of expr2 - funcName += 'BroadCast' - - outputShape = typ_3.shape - argsDict = OrderedDict() - inp1_shape = node.expr1.type.shape - inp2_shape = node.expr2.type.shape - for ii,curDimSize in enumerate(inp1_shape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - for ii,curDimSize in enumerate(inp2_shape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - for ii,curDimSize in enumerate(outputShape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - argsDict[expr_1] = "A" - argsDict[expr_2] = "B" - argsDict[expr_3] = "C" - funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, funcCall])) - prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3) - - return (prog_3, expr_3) + decl = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret) + if Type.isInt(node_type): + assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2)) + out_prog = IR.Prog([assign]) + else: + outputShape = node_type.shape + inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape + inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape + + expected_output_shape, _, _ = Util.getBroadcastShapes(inp1_shape, inp2_shape) + assert(outputShape == expected_output_shape) + out_prog = IRUtil.generateBroadcastLoopBOp(expr_1, inp1_shape, expr_2, inp2_shape, out_arr, op_ir) + + out_prog = IRUtil.prog_merge(IR.Prog([comment, cmd0, decl]), out_prog) + out_prog = IRUtil.prog_merge(prog_1, prog_2, out_prog) + return (out_prog, out_arr) + + + # We first reshape both inputs and flatten them into 1d dims. + # For simplicity consider a non-broadcast example: + # inputs : inp1_arr[s1][s2], inp2_arr[s1][s2] + # after flattening : inp1_arr_flat[s1*s2], inp2_arr_flat[s1*s2] + # for i1=[0:s1] + # for i2=[0:s2] + # idx = i1*s2 + i2 + # inp1_arr_flat[idx] = inp1_arr[i1][i2] + # inp2_arr_flat[idx] = inp2_arr[i1][i2] + # If one input is from server and the other from model we can call an optimized version of mul + # ElemWiseActModelVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat) <- optimized + # OR + # ElemWiseSecretSharedVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat) + # Finally we reshape the flattened output + # for i1=[0:s1] + # for i2=[0:s2] + # idx = i1*s2 + i2 + # out_arr[i1][i2] = out_arr_flat[idx] + # Standard broadcast rules apply to generate these flattened tensors. def visitBopElemWiseOp(self, node:AST.BOp, args=None): (prog_1, expr_1) = self.visit(node.expr1) (prog_2, expr_2) = self.visit(node.expr2) @@ -515,38 +523,148 @@ def visitBopElemWiseOp(self, node:AST.BOp, args=None): elif (node.op == AST.Operators.ElemWiseDiv): op_ir = IR.Op.Op['./'] funcName = "ElemWiseDiv" + assert False, "Did not implement div yet" + else: + assert False, "Non mul/div elemwise op" - typ_3 = node.type - expr_3 = self.getTempVar() + comment = IR.Comment(str(node.metadata)) cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf) - outputShape = typ_3.shape - argsDict = OrderedDict() - for ii,curDimSize in enumerate(outputShape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - argsDict[expr_1] = "A" - argsDict[expr_2] = "B" - argsDict[expr_3] = "C" + + node_type = node.type + # outArr[s1][s2] + out_arr = self.getTempVar() + decl_out_arr = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret) + + if Type.isInt(node_type): + assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2)) + out_prog = IR.Prog([assign]) + else: + # Flattening inputs + output_shape = node_type.shape + inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape + inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape + out_iters = self.getTempIterators(len(output_shape)) + expected_output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape) + assert(expected_output_shape == output_shape) + + # inp1_arr[i1][i2], inp2_arr[i1][i2], out_arr[i1][i2] + inp1_iters = IRUtil.getMaskedIters(broadcast_mask_1, out_iters, inp1_shape) + inp2_iters = IRUtil.getMaskedIters(broadcast_mask_2, out_iters, inp2_shape) + inp1_arr_expr = IRUtil.addIndex(expr_1, inp1_iters) + inp2_arr_expr = IRUtil.addIndex(expr_2, inp2_iters) + out_arr_expr = IRUtil.addIndex(out_arr, out_iters) + + flat_size = Util.get_volume(output_shape) + inp1_arr_flat = self.getTempVar() + inp2_arr_flat = self.getTempVar() + out_arr_flat = self.getTempVar() + flat_type = Type.Tensor([flat_size], node.expr1.type.bitlen, node.expr1.type.isSecret, node.expr1.type.taint) + # inp1_arr_flat[s1*s2] + # inp2_arr_flat[s1*s2] + # out_arr_flat[s1*s2] + decl_inp1_arr_flat = IR.Decl(inp1_arr_flat.idf, flat_type, node.expr1.type.bitlen, node.expr1.type.isSecret) + decl_inp2_arr_flat = IR.Decl(inp2_arr_flat.idf, flat_type, node.expr2.type.bitlen, node.expr2.type.isSecret) + decl_out_arr_flat = IR.Decl(out_arr_flat.idf, flat_type, node.type.bitlen, node.type.isSecret) + # idx + flat_idx = self.getTempVar() + decl_flat_idx = IR.Decl(flat_idx.idf, Type.Int(bitlen=32), bitlen=32, isSecret=False) + # For 4d, generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4); + flat_idx_expr = IR.Int(0,32) + for i in range(len(out_iters)): + vol = Util.get_volume(output_shape[i+1:]) + flat_idx_expr = IRUtil.add(flat_idx_expr, IRUtil.mul(out_iters[i], IR.Int(vol,32))) + # inp1_arr_flat[idx], inp2_arr_flat[idx], out_arr_flat[idx] + inp1_arr_flat_expr = IRUtil.addIndex(inp1_arr_flat, [flat_idx]) + inp2_arr_flat_expr = IRUtil.addIndex(inp2_arr_flat, [flat_idx]) + out_arr_flat_expr = IRUtil.addIndex(out_arr_flat, [flat_idx]) + # idx = i1*s2 + i2; + # inp1_arr_flat[idx] = inp1_arr[i1][i2] + # inp2_arr_flat[idx] = inp2_arr[i1][i2] + assign_flat_idx_expr = IR.Assn(flat_idx, flat_idx_expr) + assign_inp1_arr_flat = IR.Assn(inp1_arr_flat_expr, inp1_arr_expr) + assign_inp2_arr_flat = IR.Assn(inp2_arr_flat_expr, inp2_arr_expr) + # Flattening loop + # for i1=[0:s1] + # for i2=[0:s2] + # idx = i1*s2 + i2 + # inp1_arr_flat[idx] = inp1_arr[i1][i2] + # inp2_arr_flat[idx] = inp2_arr[i1][i2] + out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_inp1_arr_flat, assign_inp2_arr_flat]) + out_prog = IRUtil.Prog(out_loop) + decls = [decl_out_arr, decl_inp1_arr_flat, decl_inp2_arr_flat, decl_out_arr_flat, decl_flat_idx] + out_prog = IRUtil.prog_merge(IRUtil.Prog(decls), out_prog) + + # Insert call to mul/div functionality + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "input_shape" + if (node.op == AST.Operators.ElemWiseDiv): + argsDict[inp1_arr_flat] = "A" + argsDict[inp2_arr_flat] = "B" + funcName = "ElemwiseSuperDuperSecretDiv" + assert False, "Elemwise div not implemented" + else: + # If either input is a model weight we can use an optimised version for mul + # Otherwise if both are derived from client input we use the hadmaard version + isMulOptimised = False + if not(self.isModel(node.expr1)) and not(self.isModel(node.expr2)): + argsDict[inp1_arr_flat] = "A" + argsDict[inp2_arr_flat] = "B" + else: + isMulOptimised = True + # Optimised version expects the second parameter to be an input from server + if self.isModel(node.expr2): + argsDict[inp1_arr_flat] = "A" + argsDict[inp2_arr_flat] = "B" + else: + # Shuffle the params. + argsDict[inp2_arr_flat] = "A" + argsDict[inp1_arr_flat] = "B" + funcName = "ElemWiseActModelVectorMult" if isMulOptimised else "ElemWiseSecretSharedVectorMult" + argsDict[out_arr_flat] = "Output" + funcCall = IR.FuncCall(funcName, argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + + # Clear temp arrays + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "size" + argsDict[inp1_arr_flat] = "A" + funcCall = IR.FuncCall("ClearMemSecret1", argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "size" + argsDict[inp2_arr_flat] = "A" + funcCall = IR.FuncCall("ClearMemSecret1", argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + + # Unflatten output + assign_out_arr_flat = IR.Assn(out_arr_expr, out_arr_flat_expr) + out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_out_arr_flat]) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog(out_loop)) + + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "size" + argsDict[out_arr_flat] = "A" + funcCall = IR.FuncCall("ClearMemSecret1", argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + progExtraBefore = IR.Prog([]) progExtraAfter = IR.Prog([]) if (Util.Config.disableTruncOpti): - progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF) + progExtraAfter = self.addTruncateFunctionCall(node, "ElemWiseMul", out_arr, Util.Config.consSF) else: inputs_same = (expr_1.idf == expr_2.idf) expr1_sf = self.scaleFacMapping[expr_1.idf] expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac) + progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ElemWiseMul", expr_1, expr1_sf - self.scaleFac) self.scaleFacMapping[expr_1.idf] = self.scaleFac if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac)) + progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ElemWiseMul", expr_2, expr2_sf - self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac - self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac + self.scaleFacMapping[out_arr.idf] = 2*self.scaleFac - funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall])) - prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter) - - return (prog_3, expr_3) + out_prog = IRUtil.prog_merge(IRUtil.Prog([comment, cmd0]), progExtraBefore, out_prog, progExtraAfter) + return (out_prog, out_arr) def visitBopMul(self, node:AST.BOp, args=None): typ_1 = node.expr1.type diff --git a/Athos/SeeDot/IR/IRUtil.py b/Athos/SeeDot/IR/IRUtil.py index 39636076..1c590a71 100644 --- a/Athos/SeeDot/IR/IRUtil.py +++ b/Athos/SeeDot/IR/IRUtil.py @@ -174,3 +174,58 @@ def print_loop(shape:list, iters:list, cmdl_body:CmdList, factor=0) -> CmdList: cmdl_for = [For(iters[i], 0, lt(iters[i], Int(shape[i])), cmdl_for, factor), Print(Var('""'))] return cmdl_for +# For tensor A of shape = 7 x 1 x 5 +# And out_iters = [i0, i1, i2, i3] +# Broadcast mask = [True, False, True, False] +# We generate iters = A[i1][0][i3] +# If input is scalar, broadcast_mask=[] and inp_shape=[] +def getMaskedIters(broadcast_mask: list, out_iters: list, inp_shape : list): + base_idx = len(out_iters) - len(inp_shape) + masked_iters = [] + for i in range(len(broadcast_mask)): + if broadcast_mask[i]: + masked_iters.append(Int(0,32)) + else: + masked_iters.append(out_iters[base_idx]) + base_idx +=1 + return masked_iters + +# Given input +# A (4d array): 8 x 1 x 6 x 1 +# B (3d array): 7 x 1 x 5 +# We generate a loop with +# Result (4d array): 8 x 7 x 6 x 5 +# for i0=[0:8] +# for i1=[0:7] +# for i2=[0:6] +# for i3=[0:8] +# Result[i0][i1][i2][i3] = A[i0][0][i2][0] + B[i1][0][i3] +def generateBroadcastLoopBOp(expr_1, inp1_shape: list, expr_2, inp2_shape : list, expr_out, op: Op.Op): + output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape) + out_iters = [Var('i' + str(i)) for i in range(len(output_shape))] + inp1_iters = getMaskedIters(broadcast_mask_1, out_iters, inp1_shape) + inp2_iters = getMaskedIters(broadcast_mask_2, out_iters, inp2_shape) + + inp1_arr_expr = addIndex(expr_1, inp1_iters) + inp2_arr_expr = addIndex(expr_2, inp2_iters) + out_arr_expr = addIndex(expr_out, out_iters) + + assign_expr = Assn(out_arr_expr, IntBop(inp1_arr_expr, op, inp2_arr_expr)) + out_loop = loop(output_shape, out_iters, [assign_expr]) + out_prog = Prog(out_loop) + return out_prog + +# Generates the index into a flattened tensor. +# Example: +# for i1=[0:s1] +# for i2=[0:s2] +# for i3=[0:s3] +# for i4=[0:s4] +# generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4); +def getFlatArrIdxExpr(iters:list, shape:list): + assert len(iters) == len(shape), "No. of loop idx vars should be equal to loop shapes" + flat_idx_expr = Int(0,32) + for i in range(len(iters)): + vol = get_volume(shape[i+1:]) + flat_idx_expr = add(flat_idx_expr, mul(iters[i], Int(vol,32))) + return flat_idx_expr \ No newline at end of file diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 0420e751..4b6b5ec6 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -274,11 +274,9 @@ def visitBOp(self, node:AST.BOp, args=None): node.expr2.gamma = dict(node.gamma) fType = self.visit(node.expr2) - if node.op in [AST.Operators.ADD, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]: + if node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]: # Ops supporting broadcasting return self.typeCheckBroadcastOps(node, eType, fType) - elif node.op in [AST.Operators.SUB, AST.Operators.Equal]: - return self.visitBopAddLike(node, eType, fType) elif node.op == AST.Operators.MUL: return self.visitBopMul(node, eType, fType) elif node.op == AST.Operators.CONV: @@ -293,35 +291,18 @@ def typeCheckBroadcastOps(self, node:AST.BOp, eType:Type, fType:Type): # If adding a new op here which supports broadcasting, then be careful! # Currently, its assumed the op is commutative. If that is not true, following will be wrong ! - assert node.op in [AST.Operators.ADD, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv] - if (len(eType.shape) < len(fType.shape)): - # swap expr1 and expr2 -- this is valid for commutative ops - # be careful for ops which are not commutative - temp = node.expr1 - node.expr1 = node.expr2 - node.expr2 = temp - - temp = eType - eType = fType - fType = temp - - # Now true that dim(eType) >= dim(fTYpe) - assert len(eType.shape) >= len(fType.shape) - + assert node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv] if isInt(eType) and isInt(fType): node.type = Int(eType.bitlen) elif isTensor(eType) and isTensor(fType): - revETypeShape = eType.shape[::-1] - revFTypeShape = fType.shape[::-1] - for i, fTypeCurDim in enumerate(revFTypeShape): - eTypeCurDim = revETypeShape[i] - if not(eTypeCurDim==1 or fTypeCurDim==1 or eTypeCurDim==fTypeCurDim): - # broadcast not possible - raise error - print("Broadcast not possible for current node.", eType.shape, fType.shape) - assert False - - # Broadcast possible - node.type = copy.copy(eType) + output_shape, _, _ = Util.getBroadcastShapes(eType.shape, fType.shape) + node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) + elif isTensor(eType) and isInt(fType): + output_shape, _, _ = Util.getBroadcastShapes(eType.shape, []) + node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) + elif isInt(eType) and isTensor(fType): + output_shape, _, _ = Util.getBroadcastShapes([], fType.shape) + node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) else: print(eType, fType) assert False @@ -444,19 +425,6 @@ def visitBopConvTranspose(self, node:AST.BOp, eType:Type, fType:Type, args=None) node.type = Tensor(shape, eType.bitlen, eType.isSecret | fType.isSecret, getTaint_type(eType, fType)) return node.type - def visitBopAddLike(self, node:AST.BOp, eType: Type, fType: Type, args=None): - if isInt(eType) and isInt(fType): - pass - elif isTensor(eType) and isTensor(fType): - assert eType.shape == fType.shape - else: - assert False - - node.type = copy.copy(eType) - node.type.taint = getTaint_type(eType, fType) - node.type.isSecret = eType.isSecret | fType.isSecret - return node.type - def visitFunc(self, node:AST.Func, args=None): node.expr.gamma = dict(node.gamma) eType = self.visit(node.expr) diff --git a/Athos/SeeDot/Util.py b/Athos/SeeDot/Util.py index 20ea4c28..319ee1f1 100644 --- a/Athos/SeeDot/Util.py +++ b/Athos/SeeDot/Util.py @@ -81,3 +81,64 @@ def write_debug_info(name_mapping): with open('debug/seedot_ezpc_name_map.txt', 'w') as f: for val in name_mapping: f.write(val + ' ' + name_mapping[val] + '\n') + +# Broadcasting Rules: +# A (4d array): 8 x 1 x 6 x 1 +# B (3d array): 7 x 1 x 5 +# Result (4d array): 8 x 7 x 6 x 5 +# Return Values +# Shape A broadcast mask: [False, True, False, True] +# Shape B broadcast mask: [True, False, True, False] +# Result shape: [8, 7, 6, 5] +# +# If input is a scalar, pass shape as [] +def getBroadcastShapes(Shape1 : list, Shape2 : list): + #Broadcast rules apply in reverse direction + shape1 = Shape1[::-1] + shape2 = Shape2[::-1] + len1 = len(shape1) + len2 = len(shape2) + outputshape = [] + swapped = False + if len1 != len2: + if len1 > len2: + len1, len2 = len2, len1 + shape1, shape2 = shape2, shape1 + swapped = True + assert len1 < len2 + + broadcastMask1 = [False] * len1 + broadcastMask2 = [False] * len2 + + for i in range(len2): + length = 0 + if i >= len1: + #broadcastMask1[i] = True + outputshape.append(shape2[i]) + continue + if shape1[i] != shape2[i]: + if shape1[i] == 1: + outputshape.append(shape2[i]) + broadcastMask1[i] = True + elif shape2[i] == 1: + outputshape.append(shape1[i]) + broadcastMask2[i] = True + else: + print("Dimension no. {} has a mismatch of length.".format(len2 - i)) + assert False, "Cannot broadcast. Program is malformed. Atleast one length should have been 1. i1: {} i2: {}".format(shape1[i], shape2[i]) + else: + outputshape.append(shape1[i]) + + if swapped: + broadcastMask1, broadcastMask2 = broadcastMask2, broadcastMask1 + + outputshape.reverse() + broadcastMask1.reverse() + broadcastMask2.reverse() + return outputshape, broadcastMask1, broadcastMask2 + +def get_volume(shape: list): + vol = 1 + for i in shape: + vol = vol * i + return vol \ No newline at end of file From 07901e61ab78b9c176b58c2c0c0764cf31140163 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 18:43:38 +0530 Subject: [PATCH 14/72] Add a new garbage collector pass. For identity like ops (a=b), we sometimes run into use-after-free and double-free bugs. For this snippet J100 = J99 J101 = J99 + 3 <- last use of J99 J102 = J100 * 2 <- last use of J100 before we were doing: J100 = J99 J101 = J99 + 3 free(J99) J102 = J100 * 2 <- use-after-free free(J100) <- double-free now we do: J100 = J99 J101 = J99 + 3 J102 = J100 * 2 free(J100) Algorithm: We iterate through the program in reverse order and every time we see a use of a variable, we insert a free after it, unless we have already freed it before. When we check a variable has been freed, we also check whether any of its aliases have also been freed. For alias analysis, we maintain alias sets using disjoint sets. Whenever we encounter an a=b statement, we simply do a union of a and b sets. This replaces the old LivenessOpti pass. --- Athos/SeeDot/Compiler.py | 12 +- Athos/SeeDot/IR/IRBuilderCSF.py | 2 +- .../SeeDot/Optimizations/GarbageCollector.py | 266 ++++++++++++++++++ Athos/SeeDot/Optimizations/LivenessOpti.py | 155 ---------- Athos/SeeDot/Util.py | 81 +++++- 5 files changed, 354 insertions(+), 162 deletions(-) create mode 100644 Athos/SeeDot/Optimizations/GarbageCollector.py delete mode 100644 Athos/SeeDot/Optimizations/LivenessOpti.py diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index 2642f096..6cbc18ca 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -37,7 +37,8 @@ from IR.IRBuilderCSF import IRBuilderCSF from Codegen.EzPC import EzPC as EzPCCodegen import Optimizations.ReluMaxpoolOpti as ReluMaxpoolOpti -import Optimizations.LivenessOpti as LivenessOpti +import Optimizations.GarbageCollector as GarbageCollector +from collections import OrderedDict class Compiler: def __init__(self, version, target, sfType, astFile, printASTBool, consSF, bitlen, outputFileName, @@ -117,17 +118,18 @@ def run(self): print("Relu-maxpool optimization done.") if not(Util.Config.disableLivenessOpti): - print("Performing Liveness Optimization...") + print("Performing Garbage colelction...") mtdAST = MtdAST() - LivenessOpti.LivenessAnalysis().visit(ast) - LivenessOpti.LivenessOpti().visit(ast, [mtdAST, 0, {}]) - print("Liveness optimization done.") + GC = GarbageCollector.GarbageCollector(ast) + GC.run([mtdAST]) + print("Garbage collection done.") # Perform type inference and annotate nodes with type information InferType().visit(ast) if Util.Config.printASTBool: PrintAST().visit(ast) + print("\n") sys.stdout.flush() IRUtil.init() diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 6f889ae1..e1fc8430 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -783,7 +783,7 @@ def visitBopMul2DTensor(self, node:AST.BOp, args=None): # and in inference, in every linear layer, either of A or B will be a model weight. # This is required because for some backends, knowing which of A or B is a model weight # can make a difference in their performance. - modelIsA = True + assert (self.isModel(node.expr1) or self.isModel(node.expr2)), "Expecting one of A or B to be an input by the server (model weight)." modelIsA = True if (not self.isModel(node.expr1)): diff --git a/Athos/SeeDot/Optimizations/GarbageCollector.py b/Athos/SeeDot/Optimizations/GarbageCollector.py new file mode 100644 index 00000000..f0b99d68 --- /dev/null +++ b/Athos/SeeDot/Optimizations/GarbageCollector.py @@ -0,0 +1,266 @@ +''' + +Authors: Pratik Bhatu + +Copyright: +Copyright (c) 2020 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import AST.AST as AST +import Util +from AST.ASTVisitor import ASTVisitor +from AST.MtdAST import MtdAST + + +class SecretFlowAnalysis(ASTVisitor): + def __init__(self): + self.idf_to_secret = {} + self.node_to_secret = {} + + def isSecret(self, idf:str): + return self.idf_to_secret[idf] + + def visitInt(self, node:AST.Int, args): + self.node_to_secret[node] = node.isSecret + + def visitFloat(self, node:AST.Float, args): + self.node_to_secret[node] = node.isSecret + + def visitInput(self, node:AST.Input, args): + self.node_to_secret[node] = node.isSecret + + def visitId(self, node:AST.ID, args): + self.node_to_secret[node] = self.idf_to_secret[node.name] + + def visitLet(self, node:AST.Let, args): + self.visit(node.decl, args) + self.idf_to_secret[node.name.name] = self.node_to_secret[node.decl] + self.visit(node.expr, args) + + def visitDecl(self, node:AST.Decl, args): + self.node_to_secret[node] = node.isSecret + if node.valueList: + for elem in node.valueList: + self.visit(elem, args) + + def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args): + self.node_to_secret[node] = node.isSecret + for elem in node.argsList: + self.visit(elem, args) + + def visitTranspose(self, node:AST.Transpose, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitSlice(self, node:AST.Slice, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitReshape(self, node:AST.Reshape, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitPool(self, node:AST.Pool, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitUOp(self, node:AST.UOp, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitBOp(self, node:AST.BOp, args): + self.visit(node.expr1, args) + self.visit(node.expr2, args) + self.node_to_secret[node] = self.node_to_secret[node.expr1] | self.node_to_secret[node.expr1] + + def visitFunc(self, node:AST.Func, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + + def visitArgMax(self, node:AST.ArgMax, args): + self.visit(node.expr, args) + self.visit(node.dim, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] | self.node_to_secret[node.dim] + + def visitReduce(self, node:AST.Reduce, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args): + self.visit(node.expr, args) + self.visit(node.multExpr, args) + self.visit(node.addExpr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] | self.node_to_secret[node.multExpr] | self.node_to_secret[node.addExpr] + + +# A very basic alias analysis pass which creates alias sets for variables created +# through identity ops +# let a = b +class AliasAnalysis(ASTVisitor): + def __init__(self): + self.alias_sets = Util.DisjointSet() + super().__init__() + + def add_alias(self, inp1, inp2): + self.alias_sets.make_set(inp1) + self.alias_sets.make_set(inp2) + self.alias_sets.union(inp1, inp2) + + def get_alias_set(self, inp): + return self.alias_sets.get_key_set(inp) + + def visitLet(self, node:AST.Let, args): + self.visit(node.decl) + self.visit(node.expr) + + # Two IDs with same name can have diff pointers. Hence we store ID names instead of pointers. + if isinstance(node.decl, AST.ID): + self.add_alias(node.name.name, node.decl.name) + +''' + We visit the program bottom up. Every time we encounter a use of a variable, we insert + a free instruction after it, unless the variable has already been freed. + We are basically freeing variables after their last use. + + However, we also need to check for aliases of variables to avoid double frees and + use after free. + J100 = J99 + J101 = J99 + 3 <- last use of J99 + J102 = J100 * 2 <- last use of J100 + if we transform this to: + J100 = J99 + J101 = J99 + 3 + free(J99) + J102 = J100 * 2 <- use after free + free(J100) <- double free + instead we want to do: + J100 = J99 + J101 = J99 + 3 + J102 = J100 * 2 + free(J100) + .. + +''' +class GarbageCollector(ASTVisitor): + def __init__(self, ast): + self.ast = ast + self.secret_analysis = SecretFlowAnalysis() + self.secret_analysis.visit(self.ast) + self.alias_analysis = AliasAnalysis() + self.alias_analysis.visit(self.ast) + self.freed_nodes = set() + self.counter = 0 + super().__init__() + + def run(self, args): + self.visit(self.ast, args) + + def isVarFreed(self, inp): + alias_set = self.alias_analysis.get_alias_set(inp) + if alias_set is None: + return inp in self.freed_nodes + for i in alias_set: + if i in self.freed_nodes: + return True + return False + + def visitLet(self, node:AST.Let, args): + assert(isinstance(args, list)) + assert(isinstance(args[0], MtdAST)) + + self.visit(node.expr, args) + + usedVars = self.visit(node.decl, args) + if usedVars is None: + assert False, " visit of {} not implemented in GarbageCollector pass".format(str(type(node.decl))) + + varsToDeAllocate = [i for i in usedVars if not self.isVarFreed(i)] + self.freed_nodes = self.freed_nodes.union(set(varsToDeAllocate)) + + astSubTree = node.expr + mtdForNewASTNodes = {AST.ASTNode.mtdKeyTFOpName : "No-op: ClearMem", + AST.ASTNode.mtdKeyTFNodeName : ""} + for ii, curVarName in enumerate(varsToDeAllocate): + newSubTree = AST.Let(AST.ID("cv"+str(self.counter+ii)), + AST.Func(AST.Operators.ClearMemSecret if self.secret_analysis.isSecret(curVarName) else AST.Operators.ClearMemPublic, + AST.ID(curVarName)), + AST.ID("")) + self.counter += 1 + args[0].visit(newSubTree, mtdForNewASTNodes) + newSubTree.expr = astSubTree + node.expr = newSubTree + astSubTree = node.expr + + def visitInt(self, node:AST.Int, args): + return set() + + def visitFloat(self, node:AST.Float, args): + return set() + + def visitInput(self, node:AST.Input, args): + return set() + + def visitId(self, node:AST.ID, args): + return set([node.name]) + + def visitDecl(self, node:AST.Decl, args): + return set() + + def visitTranspose(self, node:AST.Transpose, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitSlice(self, node:AST.Slice, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitReshape(self, node:AST.Reshape, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitPool(self, node:AST.Pool, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitUOp(self, node:AST.UOp, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitBOp(self, node:AST.BOp, args): + usedVars = self.visit(node.expr1, args) | self.visit(node.expr2, args) + return usedVars + + def visitFunc(self, node:AST.Func, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args): + usedVars = set([]) + for elem in node.argsList: + usedVars |= self.visit(elem, args) + return usedVars + + def visitArgMax(self, node:AST.ArgMax, args): + usedVars = self.visit(node.expr, args) | self.visit(node.dim, args) + return usedVars + + def visitReduce(self, node:AST.Reduce, args): + usedVars = self.visit(node.expr, args) + return usedVars \ No newline at end of file diff --git a/Athos/SeeDot/Optimizations/LivenessOpti.py b/Athos/SeeDot/Optimizations/LivenessOpti.py deleted file mode 100644 index d69131f1..00000000 --- a/Athos/SeeDot/Optimizations/LivenessOpti.py +++ /dev/null @@ -1,155 +0,0 @@ -''' - -Authors: Nishant Kumar. - -Copyright: -Copyright (c) 2020 Microsoft Research -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -''' - -import AST.AST as AST -from AST.ASTVisitor import ASTVisitor -from AST.MtdAST import MtdAST - -#In the below analysis, each node saves what all unbound variables -# are used in its sub-tree. If the set is empty, nothing is saved. -# A subsequent pass then finds the variables -# wnich can be cleared. -class LivenessAnalysis(ASTVisitor): - optidictKey = "LivenessAnalysis" #This key will be used to store in optidict of the ASTNode - # list of all variables which are unbound in that sub-tree. - def visitInt(self, node:AST.Int, args): - return [] - - def visitFloat(self, node:AST.Float, args): - return [] - - def visitId(self, node:AST.ID, args): - unboundVars = [node.name] - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitDecl(self, node:AST.Decl, args): - return [] - - def visitTranspose(self, node:AST.Transpose, args): - unboundVars = self.visit(node.expr, args) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitSlice(self, node:AST.Slice, args): - unboundVars = self.visit(node.expr, args) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitReshape(self, node:AST.Reshape, args): - unboundVars = self.visit(node.expr, args) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitPool(self, node:AST.Pool, args): - unboundVars = self.visit(node.expr, args) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitUOp(self, node:AST.UOp, args): - unboundVars = self.visit(node.expr, args) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitBOp(self, node:AST.BOp, args): - unboundVars = list(set(self.visit(node.expr1, args) + self.visit(node.expr2, args))) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitFunc(self, node:AST.Func, args): - unboundVars = self.visit(node.expr, args) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitLet(self, node:AST.Let, args): - declVars = self.visit(node.decl, args) - exprVars = self.visit(node.expr, args) - unboundVars = list((set(declVars)|set(exprVars))-set([node.name.name])) - if isinstance(node.decl, AST.ID): - #This is of the type let J1 = J2 in J1. - # Since J1 and J2 refer to the same variable, J2 should remain bounded. - unboundVars = list(set(unboundVars) - set([node.decl.name])) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args): - unboundVarsSet = set([]) - for elem in node.argsList: - unboundVarsSet |= set(self.visit(elem, args)) - unboundVars = list(unboundVarsSet) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitArgMax(self, node:AST.ArgMax, args): - unboundVars = list(set(self.visit(node.expr, args) + self.visit(node.dim, args))) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitReduce(self, node:AST.Reduce, args): - unboundVars = list(set(self.visit(node.expr, args))) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - - def visitInput(self, node:AST.Input, args): - return [] - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args): - unboundVars = list(set(self.visit(node.expr, args) + self.visit(node.multExpr, args) + self.visit(node.addExpr, args))) - node.optidict[self.optidictKey] = unboundVars - return unboundVars - -class LivenessOpti(ASTVisitor): - def visitLet(self, node:AST.Let, args): - assert(isinstance(args, list)) - assert(isinstance(args[0], MtdAST)) - assert(isinstance(args[1], int)) - assert(isinstance(args[2], dict)) #dict {variable name string -> isSecretVariable bool} - curUnboundVars = [] - exprUnboundVars = [] - if LivenessAnalysis.optidictKey in node.optidict: - curUnboundVars = node.optidict[LivenessAnalysis.optidictKey] - if LivenessAnalysis.optidictKey in node.expr.optidict: - exprUnboundVars = node.expr.optidict[LivenessAnalysis.optidictKey] - varsToDeAllocate = list(set(curUnboundVars)-set(exprUnboundVars)) - origNodeExpr = node.expr - astSubTree = node.expr - mtdForNewASTNodes = {AST.ASTNode.mtdKeyTFOpName : "No-op: ClearMem", - AST.ASTNode.mtdKeyTFNodeName : ""} - for ii, curVarName in enumerate(varsToDeAllocate): - assert(curVarName in args[2]) - newSubTree = AST.Let(AST.ID("cv"+str(args[1]+ii)), - AST.Func(AST.Operators.ClearMemSecret if args[2][curVarName] else AST.Operators.ClearMemPublic, - AST.ID(curVarName)), - AST.ID("")) - args[0].visit(newSubTree, mtdForNewASTNodes) - newSubTree.expr = astSubTree - node.expr = newSubTree - astSubTree = node.expr - self.visit(node.name, [args[0], args[1]+len(varsToDeAllocate), args[2]]) - self.visit(node.decl, [args[0], args[1]+len(varsToDeAllocate), args[2]]) - isCurrentLetDeclarationSecret = True - if hasattr(node.decl, 'isSecret'): - isCurrentLetDeclarationSecret = node.decl.isSecret - assert(type(isCurrentLetDeclarationSecret)==bool) - self.visit(origNodeExpr, [args[0], args[1]+len(varsToDeAllocate), {**args[2], **{node.name.name: isCurrentLetDeclarationSecret}}]) diff --git a/Athos/SeeDot/Util.py b/Athos/SeeDot/Util.py index 319ee1f1..0b729185 100644 --- a/Athos/SeeDot/Util.py +++ b/Athos/SeeDot/Util.py @@ -141,4 +141,83 @@ def get_volume(shape: list): vol = 1 for i in shape: vol = vol * i - return vol \ No newline at end of file + return vol + +class DisjointSet: + class Node: + def __init__(self): + self.parent = self + self.children = [] + + def get_root(self): + if (self.parent != self): + old_parent = self.parent + self.parent = self.parent.get_root() + if self.parent != old_parent: + self.parent.children.append(self) + old_parent.children.remove(self) + return self.parent + else: + return self + + def get_all_children(self): + all_children = [] + all_children.extend(self.children) + tmp = [] + for i in all_children: + tmp.extend(i.get_all_children()) + all_children.extend(tmp) + return all_children + + def __init__(self): + self.key_to_node = {} + self.node_to_key = {} + + def inSet(self, inp): + return inp in self.key_to_node + + def make_set(self, inp): + if self.inSet(inp): + return + n = self.Node() + self.key_to_node[inp] = n + self.node_to_key[n] = inp + + def union(self, inp1, inp2): + n1 = self.key_to_node[inp1] + n2 = self.key_to_node[inp2] + r1 = n1.get_root() + r2 = n2.get_root() + if (r1 != r2): + r1.parent = r2 + r2.children.append(r1) + + def find(self, inp): + if not self.inSet(inp): + return None + return self.key_to_node[inp].get_root() + + def find_key(self, inp): + node = self.find(inp) + if node is None: + return None + return self.node_to_key[node] + + def get_set(self, inp): + if not self.inSet(inp): + return None + n = self.key_to_node[inp].get_root() + return [n] + n.get_all_children() + + def get_key_set(self, inp): + nodes = self.get_set(inp) + if nodes is None: + return None + return [self.node_to_key[i] for i in nodes] + + def print(self): + print(self.key_to_node) + print(self.node_to_key) + + def print_set(self, inp): + print(self.get_key_set(inp)) \ No newline at end of file From 7a0f955ae4b03ae944c0af04d72e5181952b6646 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 18:59:27 +0530 Subject: [PATCH 15/72] Fix tf size inference for multiple output tensors. There was an assumption that ops only have single tensor outputs. However ops like split return multiple tensors. This fixes that. --- Athos/TFCompiler/DumpTFMtData.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/Athos/TFCompiler/DumpTFMtData.py b/Athos/TFCompiler/DumpTFMtData.py index 8f49f04d..575a75f5 100644 --- a/Athos/TFCompiler/DumpTFMtData.py +++ b/Athos/TFCompiler/DumpTFMtData.py @@ -40,7 +40,7 @@ def get_optimized_graph_def(output_tensor): def save_graph_metadata(output_tensor, sess, feed_dict): #First save the graph def - graph_def = tf.get_default_graph().as_graph_def() + graph_def = sess.graph_def transforms = [ 'remove_nodes(op=Identity)', 'strip_unused_nodes', @@ -54,11 +54,17 @@ def save_graph_metadata(output_tensor, sess, feed_dict): # Save size information for tensors on which output depends tensors_to_evaluate = [] tensors_to_evaluate_names = [] - graph = tf.get_default_graph() + graph = sess.graph for node in optimized_graph_def.node: - cur_output = graph.get_operation_by_name(node.name).outputs[0] - tensors_to_evaluate.append(cur_output) - tensors_to_evaluate_names.append(node.name) + output_number = 0 + for cur_output in graph.get_operation_by_name(node.name).outputs: + tensors_to_evaluate.append(cur_output) + if output_number == 0: + tensor_name = node.name + else: + tensor_name = cur_output.name + tensors_to_evaluate_names.append(tensor_name) + output_number += 1 tensors_evaluated = sess.run(tensors_to_evaluate, feed_dict) tensors_shape = list(map(lambda x : x.shape, tensors_evaluated)) From 2dd0ce4d0beef1d88cb2a65285c6e692878ee725 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 26 Nov 2020 19:18:17 +0530 Subject: [PATCH 16/72] Improvements to compiler scripts. --- .../comparison_scripts/compare_output.sh | 12 ++-- .../comparison_scripts/convert_to_signed.py | 5 +- .../comparison_scripts/convert_to_signed.sh | 10 +++- Athos/CompilerScripts/compile_tf_graph.py | 49 ++++++++++----- .../preprocess_frozen_tf_graph.py | 59 ++++++++++++++----- Athos/CompilerScripts/tf_graph_trans.py | 13 ++-- 6 files changed, 105 insertions(+), 43 deletions(-) mode change 100755 => 100644 Athos/CompilerScripts/comparison_scripts/compare_output.sh mode change 100755 => 100644 Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh diff --git a/Athos/CompilerScripts/comparison_scripts/compare_output.sh b/Athos/CompilerScripts/comparison_scripts/compare_output.sh old mode 100755 new mode 100644 index 1362f15c..9a745f14 --- a/Athos/CompilerScripts/comparison_scripts/compare_output.sh +++ b/Athos/CompilerScripts/comparison_scripts/compare_output.sh @@ -1,12 +1,12 @@ -# Usage: tf_output.float(floatingpt) party0_output(fixedpt) SCALING_FACTOR PRECISION(upto how many points to compare?) +# Usage: tf_output.float(floatingpt) party0_output(fixedpt) BITLEN SCALING_FACTOR PRECISION(upto how many points to compare?) # This first converts unsigned fixedpt to signed SCRIPT_DIR="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" -echo "Comparing output with tensorflow output upto $4 decimal points." -$SCRIPT_DIR/convert_to_signed.sh $2 +echo "Comparing output with tensorflow output upto $5 decimal points." +$SCRIPT_DIR/convert_to_signed.sh $2 $3 #Then runs the comparison script on it. -python3 $SCRIPT_DIR/compare_output.py $1 $2_signed $3 $4 +python3 $SCRIPT_DIR/compare_output.py $1 $2_signed $4 $5 if [ "$?" -eq 0 ]; then - echo "Output matches upto ${4} decimal points" + echo "Output matches upto ${5} decimal points" else - echo "Output does not match upto ${4} decimal points" + echo "Output does not match upto ${5} decimal points" fi diff --git a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py index 0b6310c9..cd3524f8 100644 --- a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py +++ b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py @@ -1,13 +1,14 @@ import sys if __name__ == "__main__": - assert(len(sys.argv) == 3) + assert(len(sys.argv) == 4) inp_fname = sys.argv[1] out_fname = sys.argv[2] + bitlen = int(sys.argv[3]) f = open(inp_fname, 'r') op = [(int(line.rstrip())) for line in f] f.close() f = open(out_fname, 'w') for i in op: - f.write(str( i if (i<2**63) else i - 2**64) + '\n') + f.write(str( i if (i<2**(bitlen-1)) else i - 2**bitlen) + '\n') f.close() diff --git a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh old mode 100755 new mode 100644 index 95a6493b..867c2b92 --- a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh +++ b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh @@ -1,9 +1,17 @@ SCRIPT_DIR="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" inp1=$1 +bitlen=$2 +if [ -z "$bitlen" ] +then + echo "Please pass bitlen." + + exit 1 +fi + temp_1=${inp1}_tmp_cmp awk '$0==($0+0)' $inp1 > $temp_1 -python3 ${SCRIPT_DIR}/convert_to_signed.py $temp_1 ${inp1}_signed +python3 ${SCRIPT_DIR}/convert_to_signed.py $temp_1 ${inp1}_signed $bitlen rm $temp_1 diff --git a/Athos/CompilerScripts/compile_tf_graph.py b/Athos/CompilerScripts/compile_tf_graph.py index b835866d..c3f3727a 100644 --- a/Athos/CompilerScripts/compile_tf_graph.py +++ b/Athos/CompilerScripts/compile_tf_graph.py @@ -30,24 +30,45 @@ def compile(model_fname, input_t_name, output_t_name, scaling_factor, save_weigh print("Loading processed tf graph ", model_fname) graph = load_pb(model_fname) - if not check_operation_exists(graph, input_t_name): - sys.exit(input_t_name + " input does not exist in the graph") if not check_operation_exists(graph, output_t_name): sys.exit(output_t_name + " output does not exist in the graph") - - input_t = graph.get_operation_by_name(input_t_name).outputs[0] output_t = graph.get_operation_by_name(output_t_name).outputs[0] + + if input_t_name != "": + if not check_operation_exists(graph, input_t_name): + sys.exit(input_t_name + " input does not exist in the graph") + + input_t = graph.get_operation_by_name(input_t_name).outputs[0] - # Generate random tensor as input - inp_shape = input_t.shape.as_list() - if None in inp_shape: - if input_shape == []: - sys.exit("Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help.") + # Generate random tensor as input + # scalar input + if input_t.shape.dims == None: + inp_shape = [] else: - inp_shape = input_shape - rand_inp_t = np.zeros(inp_shape) + inp_shape = input_t.shape.as_list() + if None in inp_shape: + if input_shape == []: + sys.exit("Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help.") + else: + inp_shape = input_shape + rand_inp_t = np.zeros(inp_shape) + + feed_dict = {input_t: rand_inp_t} + else: + # We can collect all placeholder nodes as inputs to the model + inputs = [i for i in graph.get_operations() if i.type=="Placeholder"] + feed_dict = {} + for op in inputs: + input_t = op.outputs[0] + if input_t.shape.dims == None: + inp_shape = [] + else: + inp_shape = input_t.shape.as_list() + if None in inp_shape: + sys.exit("Please supply input names and their shapes for the input tensor as it is parametric (? dim) for this model. See --help.") + rand_inp_t = np.zeros(inp_shape) + feed_dict[input_t] = rand_inp_t - feed_dict = {input_t: rand_inp_t} with graph.as_default(): with tf.Session() as sess: # Run initializers generated by preprocessing @@ -65,7 +86,7 @@ def compile(model_fname, input_t_name, output_t_name, scaling_factor, save_weigh if save_weights: DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess) weights_fname = model_name[len("mpc_processed_"):] + '_input_weights_fixedpt_scale_' + str(scaling_factor) + '.inp' - print("Dumping model weights in ", weights_fname, ". These are to be used as input for party which owns the model") + print("\nDumping model weights in ", weights_fname, ". These are to be used as input for party which owns the model\n") DumpTFMtData.dumpTrainedWeightsInt(sess, trainVars, weights_fname, scaling_factor, 'w') def boolean_string(s): @@ -76,7 +97,7 @@ def boolean_string(s): def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--modelName", required=True, type=str, help="Name of processed tensorflow model (mpc_processed*.pb)") - parser.add_argument("--inputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") + parser.add_argument("--inputTensorName", type=str, default='', help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") parser.add_argument("--sf", default=12, type=int, help="scaling factor (int)") parser.add_argument("--saveWeights", type=boolean_string, default=False, help="Dump model weights in fixedpt {True/False}") diff --git a/Athos/CompilerScripts/preprocess_frozen_tf_graph.py b/Athos/CompilerScripts/preprocess_frozen_tf_graph.py index 835d9ba5..d73705ac 100644 --- a/Athos/CompilerScripts/preprocess_frozen_tf_graph.py +++ b/Athos/CompilerScripts/preprocess_frozen_tf_graph.py @@ -3,6 +3,10 @@ import sys import time import os + +import argparse +import os.path + # Transpose nodes require perm as compile time constants for parametric codegen # So we don't eliminate the constants we need dring compile time def get_const_names(graph): @@ -11,39 +15,66 @@ def get_const_names(graph): slice_begin_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Slice') slice_size_ops = set(i.inputs[2].op.name for i in graph.get_operations() if i.type == 'Slice') mean_axes_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Mean') - white_list = transp_perm_ops | padding_ops | slice_begin_ops | slice_size_ops | mean_axes_ops + split_dim_ops = set(i.inputs[0].op.name for i in graph.get_operations() if i.type == 'Split') + white_list = transp_perm_ops | padding_ops | slice_begin_ops | slice_size_ops | mean_axes_ops | split_dim_ops all_const_ops = set(i.name for i in graph.get_operations() if i.type == 'Const') return list(all_const_ops - white_list) -if __name__ == "__main__": - if len(sys.argv) != 2: - print("Usage: python preprocess_frozen_tf_graph.py tf_model_name.pb") - sys.exit() - else: - input_fname = sys.argv[1] +def check_operation_exists(graph, tensor_name): + op_list = [i.name for i in graph.get_operations()] + return tensor_name in op_list +def optimize(input_fname, output_t_name): + if not input_fname.endswith('.pb'): + sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") + actual_fname = os.path.basename(input_fname) dirname = os.path.dirname(input_fname) output_fname = os.path.join(dirname, "mpc_processed_" + actual_fname) print("Loading ", input_fname, "for processing.") - exec_graph = load_pb(input_fname) - + graph = load_pb(input_fname) + + if not check_operation_exists(graph, output_t_name): + sys.exit(output_t_name + " output does not exist in the graph") + input_names = [i.name for i in graph.get_operations() if i.type=="Placeholder"] + + #graph = remove_dead_nodes(graph, input_names, [output_t_name]) + print("\n\nThis process will take some time to run as we execute portions of the graph.\n\n") - time.sleep(5) + time.sleep(1) # Fold away all static computations + + print("Running fold splits") + graph = fold_splits(graph) + print(graph.get_operations(),end="\n\n") print("Running constant folding") - exec_graph = fold_splits(exec_graph) - exec_graph = fold_constants(exec_graph) + graph = fold_constants(graph) + # Convert constants to variables so as to separate the data and the generated code # Otherwise huge arrays will show up as constants in the generated code, thereby # increasing binary size. print("Convert frozen constants to variables") - exec_graph = convert_consts_to_var(exec_graph, get_const_names(exec_graph)) + graph = convert_consts_to_var(graph, get_const_names(graph)) + + input_names = [i.name for i in graph.get_operations() if i.type=="Placeholder"] + #graph = remove_dead_nodes(graph, input_names, [output_t_name]) + # At this stage the graph still has constants embedded in it # in the assign nodes for variables. We cannot execute the graph without # these constants. However after inferring the size, we can call remove_dead_nodes # to optimize away the constants and assign nodes and make the graph amenable # for codegen - dump_pb(exec_graph, output_fname) + dump_pb(graph, output_fname) print("The processed graph is dumped in ", output_fname) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--modelName", required=True, type=str, help="Name of tensorflow model (*.pb)") + parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the output tensor for the model. (Op name, dont add '/:0' suffix)") + args = parser.parse_args() + return args + +if __name__ == '__main__': + args = parse_args() + optimize(args.modelName, args.outputTensorName) diff --git a/Athos/CompilerScripts/tf_graph_trans.py b/Athos/CompilerScripts/tf_graph_trans.py index 5877feeb..918006fe 100644 --- a/Athos/CompilerScripts/tf_graph_trans.py +++ b/Athos/CompilerScripts/tf_graph_trans.py @@ -16,26 +16,26 @@ def delete_nodes(graph, ops): tf.import_graph_def(new_gd, name="") return new_graph -def remove_dead_nodes(graph, input_tensors, output_tensors): +def remove_dead_nodes(graph, in_list, out_list): transforms = ['remove_nodes(op=Identity)', 'strip_unused_nodes'] - in_list = [i.name for i in input_tensors] - out_list = [i.name for i in output_tensors] optimized_graph_def = TransformGraph(graph.as_graph_def(), in_list, out_list, transforms) with tf.Graph().as_default() as opt_graph: tf.import_graph_def(optimized_graph_def, name="") return opt_graph def convert_consts_to_var(graph, const_names_list): + const_var_names_pairs = [] ops_to_delete = [] with graph.as_default(): + preexisting_vars = [tf.get_variable(i.name, i.outputs[0].shape) for i in graph.get_operations() if i.type=="VariableV2" or i.type=="Variable"] + var_list = [] for name in const_names_list: - #tensor = graph.get_tensor_by_name('{}:0'.format(name)) tensor = graph.get_operation_by_name(name).outputs[0] with tf.Session() as sess: t_value = sess.run(tensor) - t_name = '{}_const_var'.format(name) + t_name = '{}_mpc_const_var'.format(name) var = tf.Variable(t_value, name=t_name) const_var_names_pairs.append((name, t_name)) var_list.append(var) @@ -45,7 +45,8 @@ def convert_consts_to_var(graph, const_names_list): var_op = graph.get_operation_by_name('{}/read'.format(var_name)) ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_op)) ops_to_delete.append(const_op) - tf.compat.v1.variables_initializer(var_list, 'init_constvars') + + tf.compat.v1.variables_initializer(var_list + preexisting_vars, 'init_constvars') return delete_nodes(graph, ops_to_delete) def get_inputs(op): From 1a3b304c6da6a7686d51223f78810b19670ed833 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 21 Dec 2020 13:44:22 +0530 Subject: [PATCH 17/72] Add FusedBatchNorm in GarbageCollector and bugfix. Fix attribute parsing of strings. Don't do automatic scale down of argmax. --- Athos/SeeDot/Compiler.py | 4 +- Athos/SeeDot/IR/IRBuilderCSF.py | 40 +++++++++---------- .../SeeDot/Optimizations/GarbageCollector.py | 6 +++ Athos/TFCompiler/Graph.py | 2 +- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index 6cbc18ca..258c95e7 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -83,7 +83,7 @@ def fixOuputScale(self, res:(IR.Prog, IR.Expr), compiler:IRBuilderCSF): prog = res[0] expr = res[1] output_scale = compiler.scaleFacMapping[expr.idf] - if output_scale == Util.Config.consSF: + if output_scale == -1 or output_scale == Util.Config.consSF: return (prog, expr) elif output_scale > Util.Config.consSF: scale_down = output_scale - Util.Config.consSF @@ -105,7 +105,7 @@ def fixOuputScale(self, res:(IR.Prog, IR.Expr), compiler:IRBuilderCSF): prog = IRUtil.prog_merge(prog, new_prog) return (prog, expr) else: - assert False, "Scale up shouldnt be required of final output. We lost precision somewhere" + assert False, "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format(output_scale, Util.Config.consSF) def run(self): with open(Util.Config.astFile, 'rb') as ff: diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index e1fc8430..fb0ce8f1 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -814,8 +814,8 @@ def visitBopMul2DTensor(self, node:AST.BOp, args=None): return (prog_3, expr_3) def visitBopConv(self, node:AST.BOp, args=None): - (prog1, expr1) = self.visit(node.expr1) - (prog2, expr2) = self.visit(node.expr2) + (prog1, expr_1) = self.visit(node.expr1) + (prog2, expr_2) = self.visit(node.expr2) convDim = 2 if (AST.PaddingKeysDict.ConvDim in node.options): @@ -831,7 +831,7 @@ def visitBopConv(self, node:AST.BOp, args=None): assert(False) returnExpr = self.getTempVar() - comment = IR.Comment(expr1.idf + ' # ' + expr2.idf + ', convDim = ' + str(convDim)) + comment = IR.Comment(expr_1.idf + ' # ' + expr_2.idf + ', convDim = ' + str(convDim)) funcCallArgsDict = OrderedDict() funcCallArgsDict[IR.Int(N, 32)] = "N" if convDim == 3: @@ -861,8 +861,8 @@ def visitBopConv(self, node:AST.BOp, args=None): funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.group], 32)] = "G" isGroupConv = True - funcCallArgsDict[expr1] = "input" - funcCallArgsDict[expr2] = "filter" + funcCallArgsDict[expr_1] = "input" + funcCallArgsDict[expr_2] = "filter" if convDim == 3: funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" funcCallArgsDict[returnExpr] = "output" @@ -886,13 +886,13 @@ def visitBopConv(self, node:AST.BOp, args=None): progExtraAfter = self.addTruncateFunctionCall(node, "Conv", returnExpr, Util.Config.consSF) else: inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr1.idf] - expr2_sf = self.scaleFacMapping[expr2.idf] + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr1, expr1_sf-self.scaleFac) + progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr_1, expr1_sf-self.scaleFac) self.scaleFacMapping[expr1.idf] = self.scaleFac if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr2, expr2_sf-self.scaleFac)) + progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr_2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac @@ -901,8 +901,8 @@ def visitBopConv(self, node:AST.BOp, args=None): return (returnProg, returnExpr) def visitBopConvTranspose(self, node:AST.BOp, args=None): - (prog1, expr1) = self.visit(node.expr1) - (prog2, expr2) = self.visit(node.expr2) + (prog1, expr_1) = self.visit(node.expr1) + (prog2, expr_2) = self.visit(node.expr2) convDim = 2 if (AST.PaddingKeysDict.ConvDim in node.options): @@ -942,7 +942,7 @@ def visitBopConvTranspose(self, node:AST.BOp, args=None): assert(AST.Operators.findConvOutputImgSize(d_prime_tilde, pad_d_tr_total, FD, stride_d_tr) == D) returnExpr = self.getTempVar() - comment = IR.Comment(expr1.idf + ' #T ' + expr2.idf + ', convDim = ' + str(convDim)) + comment = IR.Comment(expr_1.idf + ' #T ' + expr2.idf + ', convDim = ' + str(convDim)) funcCallArgsDict = OrderedDict() funcCallArgsDict[IR.Int(N, 32)] = "N" if convDim==3: @@ -971,8 +971,8 @@ def visitBopConvTranspose(self, node:AST.BOp, args=None): funcCallArgsDict[IR.Int(strideH, 32)] = "strideH" funcCallArgsDict[IR.Int(strideW, 32)] = "strideW" - funcCallArgsDict[expr1] = "input" - funcCallArgsDict[expr2] = "filter" + funcCallArgsDict[expr_1] = "input" + funcCallArgsDict[expr_2] = "filter" if convDim == 3: funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" funcCallArgsDict[returnExpr] = "output" @@ -992,13 +992,13 @@ def visitBopConvTranspose(self, node:AST.BOp, args=None): progExtraAfter = self.addTruncateFunctionCall(node, "ConvTranspose", returnExpr, self.scaleFac) else: inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr1.idf] - expr2_sf = self.scaleFacMapping[expr2.idf] + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ConvTranspose", expr1, expr1_sf-self.scaleFac) - self.scaleFacMapping[expr1.idf] = self.scaleFac + progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ConvTranspose", expr_1, expr1_sf-self.scaleFac) + self.scaleFacMapping[expr_1.idf] = self.scaleFac if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ConvTranspose", expr2, expr2_sf-self.scaleFac)) + progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ConvTranspose", expr_2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr2.idf] = self.scaleFac self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac @@ -1216,7 +1216,7 @@ def visitArgMax(self, node:AST.ArgMax, args=None): funcArgsList[tmpExpr] = "outArr" if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[tmpExpr.idf] = 0 #TODO -- is this the right thing to do? + self.scaleFacMapping[tmpExpr.idf] = -1 funcCall = IR.FuncCall("ArgMax" + self.varNameDelim + str(len(outputShape)), funcArgsList) comment = IR.Comment(str(node.metadata)) diff --git a/Athos/SeeDot/Optimizations/GarbageCollector.py b/Athos/SeeDot/Optimizations/GarbageCollector.py index f0b99d68..27a2dcaf 100644 --- a/Athos/SeeDot/Optimizations/GarbageCollector.py +++ b/Athos/SeeDot/Optimizations/GarbageCollector.py @@ -263,4 +263,10 @@ def visitArgMax(self, node:AST.ArgMax, args): def visitReduce(self, node:AST.Reduce, args): usedVars = self.visit(node.expr, args) + return usedVars + + def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args): + usedVars = self.visit(node.expr, args) + usedVars |= self.visit(node.multExpr, args) + usedVars |= self.visit(node.addExpr, args) return usedVars \ No newline at end of file diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index 7acbd1d3..4c6bed8d 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -404,7 +404,7 @@ def readFromFilePointer(self, fileP, cnt): return (True, cnt) elif (curToken == "s:"): if (errIfTokensNotMinLen(tokens, 2, cnt, "Value")): return (False, cnt) - self.__val = tokens[1] + self.__val = tokens[1][1:-1] elif (curToken == "i:"): if (errIfTokensNotMinLen(tokens, 2, cnt, "Value")): return (False, cnt) self.__val = int(tokens[1]) From db51c95f8a79c1a28c38335747671331d5763f42 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Tue, 22 Dec 2020 17:06:40 +0530 Subject: [PATCH 18/72] Convert Slice operator to compile-time codgen. --- Athos/TFCompiler/TFNodesAST.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index 1a8f1460..e1345b12 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -555,14 +555,21 @@ def ExpandDims(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarSt def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef) == 3) - curNodeDataType = curNode.getAttrMapRef()["T"].getDataType() - retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.CreateCopy.name, - [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), # of this - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), # begin idx - AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]) # size - ]) - return (None, { curNode.getName() : retAST}) + beginNode = graph.__getitem__(inputsRef[1]) + sizeNode = graph.__getitem__(inputsRef[2]) + assert beginNode.getAttrVal("value") is not None, "begin {} of Slice node {} has to be a constant".format(inputsRef[1], curNode.getName()) + assert sizeNode.getAttrVal("value") is not None, "size {} of Slice node {} has to be a constant".format(inputsRef[2], curNode.getName()) + begin = beginNode.getAttrVal("value").getTensor().getContentAsValArr() + size = sizeNode.getAttrVal("value").getTensor().getContentAsValArr() + assert begin is not None + assert size is not None + assert len(begin) == len(size) + subscriptRanges = [] + for i in range(0,len(size)): + subscriptRanges.append((begin[i], begin[i] + size[i] - 1)) + + return (None, { curNode.getName() : AST.Slice(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + subscriptRanges)}) def Tile(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() From 97fb6ce33b1ecda6916837e8587e6ca66a65fa67 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 23 Dec 2020 10:16:51 +0530 Subject: [PATCH 19/72] Fix typo for Conv2D truncation --- Athos/SeeDot/IR/IRBuilderCSF.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index fb0ce8f1..de0cf2f5 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -890,7 +890,7 @@ def visitBopConv(self, node:AST.BOp, args=None): expr2_sf = self.scaleFacMapping[expr_2.idf] if (expr1_sf > self.scaleFac): progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr_1, expr1_sf-self.scaleFac) - self.scaleFacMapping[expr1.idf] = self.scaleFac + self.scaleFacMapping[expr_1.idf] = self.scaleFac if (not inputs_same) and (expr2_sf > self.scaleFac): progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr_2, expr2_sf-self.scaleFac)) self.scaleFacMapping[expr_2.idf] = self.scaleFac From 378b347aecca83b48ee42f8b489fe4e766d15ddc Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 23 Dec 2020 10:18:18 +0530 Subject: [PATCH 20/72] Add support for ":0" indexing of nodes in graphdef Sometimes the generated graph defs specify the 0'th output explicitly. Example: node { name: "abc" input: "Placeholder_1:0" } node { name: "Placeholder_1" op: "Placeholder" } Node abc specifies that it wants the 0th output of the Placeholder_1 node. However in single output nodes, the tensor name is same as node name. So Placeholder_1:0 tensor cannot be found. The same is also true for the 0th output of multiple output nodes (output tensor names will be node_name,node_name:1,..) So while parsing the graph def we strip away any ":0" from the input names. --- Athos/TFCompiler/Graph.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index 4c6bed8d..ec8440df 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -579,7 +579,12 @@ def readFromFilePointer(self, fileP, cnt): self.__op = tokens[1][1:-1] elif (curToken == "input:"): if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt) - self.__inputs.append(tokens[1][1:-1]) + input_name = tokens[1][1:-1] + # Sometimes graph defs generated specify 0'th output explicitly whereas the node names do not + # contain that. So we strip it + if input_name.endswith(":0"): + input_name = input_name[:-2] + self.__inputs.append(input_name) elif (curToken == "attr"): (noParseError, cnt) = self.readAttrFromFilePointer(fileP, cnt) if (not(noParseError)): From fef8760e1d5befbff5b732820a9ee0fbbad8ac28 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 23 Dec 2020 10:29:27 +0530 Subject: [PATCH 21/72] Improve assert message for broadcasts in typeinfer --- Athos/SeeDot/Type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 4b6b5ec6..33f880bb 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -222,7 +222,7 @@ def visitSlice(self, node:AST.Slice, args=None): assert(len(shape) == len(exprType.shape)) for i in range(0,len(shape)): - assert(shape[i] <= exprType.shape[i]) + assert shape[i] <= exprType.shape[i], " for {}".format(node.metadata) node.type = Tensor(shape, exprType.bitlen, exprType.isSecret, exprType.taint) return node.type From 895c57d2b0915fbec5230086dc967e6047e50850 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 23 Dec 2020 11:28:46 +0530 Subject: [PATCH 22/72] Add grappler opts and new script to compile graphs. Usage: python CompileTFGraph.py --config config.json where a sample config.json looks like this: { "model_name":"full_kernel.pb", "input_tensors":{ "actual_input_1":"2,245,234,3", "input2":"2,245,234,3" }, "output_tensors":[ "output1", "output2" ], "scale":10, "bitlength":63, "mode":"SCI", "save_weights" : true } Do python CompileTFGraph.py --help for seeing all the options. --- Athos/CompileTFGraph.py | 182 +++++++++++++++++++++ Athos/CompileTFGraph.sh | 206 ------------------------ Athos/CompilerScripts/compile_tf.py | 165 +++++++++++++++++++ Athos/CompilerScripts/grappler.py | 176 ++++++++++++++++++++ Athos/CompilerScripts/parse_config.py | 181 +++++++++++++++++++++ Athos/CompilerScripts/tf_graph_io.py | 12 ++ Athos/CompilerScripts/tf_graph_trans.py | 1 - Athos/TFCompiler/DumpTFMtData.py | 88 ++++++---- Athos/TFCompiler/ProcessTFGraph.py | 21 ++- 9 files changed, 786 insertions(+), 246 deletions(-) create mode 100644 Athos/CompileTFGraph.py delete mode 100755 Athos/CompileTFGraph.sh create mode 100644 Athos/CompilerScripts/compile_tf.py create mode 100644 Athos/CompilerScripts/grappler.py create mode 100644 Athos/CompilerScripts/parse_config.py diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py new file mode 100644 index 00000000..71d893ec --- /dev/null +++ b/Athos/CompileTFGraph.py @@ -0,0 +1,182 @@ +import argparse +from argparse import RawTextHelpFormatter + +import os +import os.path +import json +import sys + +import TFCompiler.ProcessTFGraph as Athos +import CompilerScripts.parse_config as parse_config +import CompilerScripts.compile_tf as compile_tf + + +def parse_args(): + parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--config", + required=True, + type=str, + help="""Path to the config json file +Config file should be a json in the following format: +{ + // Mandatory options + + "model_name":"model.pb", // Tensorflow protobuf file to compile. + "output_tensors":[ + "output1", + "output2" + ], + "target":"PORTHOS2PC", // Compilation target. ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC + + // Optional options + "scale":10, // Scaling factor to compile for. Defaults to 12. + "bitlength":64, // Bit length to compile for. Defaults to 64. + "save_weights" : true, // Save model scaled weights in fixed point. Defaults to false. + + "input_tensors":{ // Name and shape of the input tensors + "actual_input_1":"224,244,3", // for the model. Not required if the + "input2":"2,245,234,3" // placeholder nodes have shape info. + }, + "modulo" : 32, // Modulo to be used for shares. Applicable for + // CPPRING/PORTHOS2PC backend. For + // PORTHOS2PC + backend=OT => Power of 2 + // PORTHOS2PC + backend=HE => Prime value." + + "backend" : "OT", // Backend to be used - OT/HE (default OT). + // Only applicable for PORTHOS2PC backend + + "disable_all_hlil_opts" : false, // Disable all optimizations in HLIL + "disable_relu_maxpool_opts" : false, // Disable Relu-Maxpool optimization + "disable_garbage_collection" : false, // Disable Garbage Collection optimization + "disable_trunc_opts" : false // Disable truncation placement optimization +} +""", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + params = parse_config.get_params(args.config) + # Mandatory + model_name = params["model_name"] + input_tensor_info = params["input_tensors"] + output_tensors = params["output_tensors"] + scale = 12 if params["scale"] is None else params["scale"] + bitlength = 64 if params["bitlength"] is None else params["bitlength"] + target = params["target"] + save_weights = params["save_weights"] + save_weights = False if save_weights is None else save_weights + + assert bitlength <= 64 and bitlength >= 1, "Bitlen must be >= 1 and <= 64" + assert target in [ + "PORTHOS", + "PORTHOS2PC", + "ABY", + "CPP", + "CPPRING", + ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" + + athos_dir = os.path.dirname(os.path.abspath(__file__)) + model_abs_path = os.path.abspath(model_name) + model_abs_dir = os.path.dirname(model_abs_path) + # Generate graphdef and sizeInfo metadata + compile_tf.compile( + model_name, input_tensor_info, output_tensors, scale, save_weights + ) + + # Compile to seedot. Generate AST in model directory + Athos.process_tf_graph(model_abs_path) + + # Compile to ezpc + model_base_name = model_name[:-3] + ezpc_file_name = "{mname}_{bl}_{target}.ezpc".format( + mname=model_base_name, bl=bitlength, target=target.lower() + ) + ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) + disable_all_hlil_opts = ( + False + if params["disable_all_hlil_opts"] is None + else params["disable_all_hlil_opts"] + ) + disable_relu_maxpool_opts = ( + False + if params["disable_relu_maxpool_opts"] is None + else params["disable_relu_maxpool_opts"] + ) + disable_garbage_collection = ( + False + if params["disable_garbage_collection"] is None + else params["disable_garbage_collection"] + ) + disable_trunc_opts = ( + False if params["disable_trunc_opts"] is None else params["disable_trunc_opts"] + ) + seedot_args = "" + seedot_args += "--astFile {}/astOutput.pkl --consSF {} ".format( + model_abs_dir, scale + ) + seedot_args += "--bitlen {} --outputFileName {} ".format(bitlength, ezpc_abs_path) + seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) + seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) + seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) + seedot_args += "--disableTruncOpti {} ".format(disable_trunc_opts) + + seedot_script = os.path.join(athos_dir, "SeeDot", "SeeDot.py") + print("python3 {} ".format(seedot_script) + seedot_args) + os.system("python3 {} ".format(seedot_script) + seedot_args) + + # Add library functions + if target in ["ABY", "CPPRING"]: + library = "cpp" + else: + library = target.lower() + + lib_bitlength = 64 if bitlength > 32 else 32 + library_dir = os.path.join(athos_dir, "TFEzPCLibrary") + common = os.path.join(library_dir, "Library{}_common.ezpc".format(lib_bitlength)) + if library == "cpp": + pre = os.path.join( + library_dir, "Library{}_{}_pre.ezpc".format(lib_bitlength, library) + ) + post = os.path.join( + library_dir, "Library{}_{}_post.ezpc".format(lib_bitlength, library) + ) + else: + pre = os.path.join( + library_dir, "Library{}_{}.ezpc".format(lib_bitlength, library) + ) + post = "" + temp = os.path.join(model_abs_dir, "temp.ezpc") + os.system( + "cat {pre} {common} {post} {ezpc}> {temp}".format( + pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp + ) + ) + os.system("mv {temp} {ezpc}".format(temp=temp, ezpc=ezpc_abs_path)) + + modulo = params["modulo"] + backend = "OT" if params["backend"] is None else params["backend"] + ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") + os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) + os.chdir(ezpc_dir) + ezpc_args = "" + ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac".format( + bl=lib_bitlength, target=target + ) + output_name = ezpc_file_name[:-5] + "0.cpp" + if modulo is not None: + ezpc_args += "--modulo {} ".format(modulo) + if target == "PORTHOS2PC": + ezpc_args += "--backend {} ".format(backend.upper()) + output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) + if target in ["PORTHOS"]: + ezpc_args += "--sf {} ".format(scale) + os.system( + "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args + ) + os.system( + "cp {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) + ) diff --git a/Athos/CompileTFGraph.sh b/Athos/CompileTFGraph.sh deleted file mode 100755 index 8422ebce..00000000 --- a/Athos/CompileTFGraph.sh +++ /dev/null @@ -1,206 +0,0 @@ -#!/bin/bash - -# Authors: Nishant Kumar, Pratik Bhatu. - -# Copyright: -# Copyright (c) 2020 Microsoft Research -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -########################################################################################## -# This is the CrypTFlow compilation script. -# Use this on a network to compile to MPC protocol. -# By default, this assumes there is a ezpc repo one level up - if you want to change it, -# please use Paths.config to override the default paths. -# Same goes for Porthos repository. -# NOTE : When overriding paths in Paths.config, assumption is there is no '/' at the end. -########################################################################################## - -# Load overriden paths from config file -. Paths.config - -echo -e "Loaded paths: EzPCDir - $EzPCDir, PorthosDir - $PorthosDir" - -usage() { - echo -e "CrypTFlow compilation script. Options:"; - echo -e "<-b|--bitlen> :: Bit length to compile for. Defaults to 64"; - echo -e "<-s|--scaling-fac> :: Scaling factor to compile for. Defaults to 12."; - echo -e "<-t|--target> :: Compilation target. Possible options: ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC. Defaults to CPP."; - echo -e "<-f|--filename> :: Tensorflow protobuf file to compile." - echo -e "<--modulo> :: Modulo to be used for shares. Applicable for CPPRING/PORTHOS2PC backend. For PORTHOS2PC, for backend type OT, this should be power of 2 and for backend type HE, this should be a prime." - echo -e "<--backend> :: Backend to be used - OT/HE (default OT). Applicable for PORTHOS2PC backend." - echo -e "<--disable-hlil-all-opti> :: Disable all optimizations in HLIL." - echo -e "<--disable-rmo> :: Disable Relu-Maxpool optimization." - echo -e "<--disable-liveness-opti> :: Disable Liveness Optimization." - echo -e "<--disable-trunc-opti> :: Disable truncation placement optimization." - echo -e "<-h|--help> :: help options."; - exit 1; -} - -BITLEN="64" -SCALINGFACTOR="12" -COMPILATIONTARGET="CPP" -EXECPYTHONARGS="" -while [[ $# -gt 0 ]] -do - key="$1" - case $key in - -b|--bitlen) - BITLEN="$2" - shift # past argument - shift # past value - ;; - -s|--scaling-fac) - SCALINGFACTOR="$2" - shift # past argument - shift # past value - ;; - -t|--target) - COMPILATIONTARGET="$2" - shift # past argument - shift # past value - ;; - -f|--filename) - FILENAME="$2" - shift - shift - ;; - --modulo) - MODULO="$2" - shift - shift - ;; - --backend) - BACKEND="$2" - shift - shift - ;; - -h|--help) - HELP=Y - shift # past one arg - ;; - --disable-hlil-all-opti) - DisableHLILAllOpti=Y - shift # past one arg - ;; - --disable-rmo) - DisableRMO=Y - shift # past one arg - ;; - --disable-liveness-opti) - DisableLivenessOpti=Y - shift # past one arg - ;; - --disable-trunc-opti) - DisableTruncOpti=Y - shift # past one arg - ;; - *) # unknown option - usage - ;; - esac -done - -if [ ! -z "$HELP" ] || [ -z "$FILENAME" ] ; then - usage -fi - -ACTUALBITLEN="${BITLEN}" -if [ "$ACTUALBITLEN" -gt 32 ]; then - BITLEN="64" -else - BITLEN="32" -fi - -compilationTargetLower=$(echo "$COMPILATIONTARGET" | awk '{print tolower($0)}') -compilationTargetHigher=$(echo "$COMPILATIONTARGET" | awk '{print toupper($0)}') -givenDirPath=$(dirname "$FILENAME") -fullDirPath=$(realpath "$givenDirPath") -porthosFullDirPath=$( realpath "$PorthosDir") -baseFileName=$(basename -- "$FILENAME") -extension="${baseFileName##*.}" -actualFileName="${baseFileName%.*}" #without extension -fullFilePath=$(realpath "$FILENAME") -ezpcOutputFileName=${actualFileName}'_'${BITLEN}'_'${compilationTargetLower} -ezpcOutputFullFileName=${fullDirPath}'/'${ezpcOutputFileName}'.ezpc' -finalCodeOutputFileName=${ezpcOutputFileName}'0.cpp' -if [ "$extension" != "pb" ]; then - echo -e "Error: Provide a tensorflow pb file to compile." - usage -fi -cd "$fullDirPath" - -cd - > /dev/null -cd ./TFCompiler -python3 ProcessTFGraph.py "$fullFilePath" -cd ../SeeDot -seedotArgs="--astFile ${fullDirPath}/astOutput.pkl --consSF ${SCALINGFACTOR} --bitlen ${ACTUALBITLEN} --outputFileName ${ezpcOutputFullFileName}" -#Temporarily always disable trunc optimization. TODO: Remove when fixed. -DisableTruncOpti=Y -if [ ! -z "$DisableHLILAllOpti" ]; then - seedotArgs="${seedotArgs} --disableAllOpti True" -fi -if [ ! -z "$DisableRMO" ]; then - seedotArgs="${seedotArgs} --disableRMO True" -fi -if [ ! -z "$DisableLivenessOpti" ]; then - seedotArgs="${seedotArgs} --disableLivenessOpti True" -fi -if [ ! -z "$DisableTruncOpti" ]; then - seedotArgs="${seedotArgs} --disableTruncOpti True" -fi -python3 SeeDot.py $seedotArgs -cd .. -libraryFile="$compilationTargetLower" -if [ "$compilationTargetLower" == "aby" ] || [ "$compilationTargetLower" == "cppring" ] ; then - libraryFile="cpp" -fi -if [ "$libraryFile" == "cpp" ];then - # CPP/ABY backend - cat "./TFEzPCLibrary/Library${BITLEN}_${libraryFile}_pre.ezpc" "./TFEzPCLibrary/Library${BITLEN}_common.ezpc" "./TFEzPCLibrary/Library${BITLEN}_${libraryFile}_post.ezpc" "$ezpcOutputFullFileName" > temp -else - cat "./TFEzPCLibrary/Library${BITLEN}_${libraryFile}.ezpc" "./TFEzPCLibrary/Library${BITLEN}_common.ezpc" "$ezpcOutputFullFileName" > temp -fi -mv temp "$ezpcOutputFullFileName" -cp "$ezpcOutputFullFileName" "$EzPCDir/EzPC" -cd "$EzPCDir/EzPC" -eval `opam config env` -ezpcArgs="--bitlen ${ACTUALBITLEN} --codegen ${compilationTargetHigher} --disable-tac" -if [ ! -z "$MODULO" ]; then - ezpcArgs="${ezpcArgs} --modulo ${MODULO}" -fi -if [ ! -z "$BACKEND" ]; then - backendUpper=$(echo "$BACKEND" | awk '{print toupper($0)}') - ezpcArgs="${ezpcArgs} --backend ${backendUpper}" - finalCodeOutputFileName=${ezpcOutputFileName}_${backendUpper}'0.cpp' -fi -if [ "$compilationTargetLower" == "porthos" ] ; then - ezpcArgs="${ezpcArgs} --sf ${SCALINGFACTOR}" -fi -./ezpc.sh "$ezpcOutputFullFileName" ${ezpcArgs} -if [ "$compilationTargetLower" == "cpp" ] || [ "$compilationTargetLower" == "cppring" ] ; then - cd "$fullDirPath" - g++ -O3 "$finalCodeOutputFileName" -o "$actualFileName.out" - echo -e "All compilation done." -else - cd - > /dev/null - echo -e "All compilation done." - if hash clang-format 2> /dev/null; then - clang-format -style=LLVM $fullDirPath/$finalCodeOutputFileName > tmp_clang - mv tmp_clang $fullDirPath/$finalCodeOutputFileName - fi -fi - diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py new file mode 100644 index 00000000..fa9accc8 --- /dev/null +++ b/Athos/CompilerScripts/compile_tf.py @@ -0,0 +1,165 @@ +import argparse +import os.path +import json +import sys + +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import tensorflow as tf +import numpy as np + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +import parse_config +import tf_graph_io +import grappler + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "TFCompiler")) +import DumpTFMtData + + +def get_graph_from(graph_def): + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name="") + return graph + + +def check_operation_exists(graph, op): + op_list = [i.name for i in graph.get_operations()] + return op in op_list + + +def tensors_exist(graph, tensor_names): + op_list = [i.name for i in graph.get_operations()] + for i in tensor_names: + assert i in op_list, "input " + i + " does not exist in the graph" + return True + + +def set_input_shapes(graph, input_t_info): + tensor_names = input_t_info.keys() + assert tensors_exist(graph, tensor_names) + + graph_def = graph.as_graph_def() + inputs = [i for i in graph.get_operations() if i.type == "Placeholder"] + + input_map = {} + with tf.Graph().as_default() as new_graph: + for i in inputs: + if i.name not in input_t_info: + continue + shape = input_t_info[i.name] + input_map[i.name] = tf.compat.v1.placeholder( + i.get_attr("dtype"), shape=shape, name=i.name + ) + tf.import_graph_def(graph_def, input_map=input_map, name="") + return new_graph + + +def get_tensor(graph, name): + return graph.get_operation_by_name(name).outputs[0] + + +def infer_input_info(graph): + input_t_info = {} + inputs = [i for i in graph.get_operations() if i.type == "Placeholder"] + for i in inputs: + input_t = i.outputs[0] + if input_t.shape.dims == None: + inp_shape = [] + else: + inp_shape = input_t.shape.as_list() + assert None not in inp_shape, "Placeholder node " + i.name + "has unknown" + +" shape. Please specify name and shape in config" + input_t_info[i.name] = inp_shape + return input_t_info + + +# Generates the computation graph and tensor size metadata and saves them in +# the model directory. +# Optionaly dumps model weights as fixedpt in specified scaling factor +def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weights): + model_name = os.path.basename(model_fname)[:-3] + print("Loading tf graph ", model_fname) + graph = tf_graph_io.load_pb(model_fname) + assert tensors_exist(graph, output_t_names) + + if input_t_info == {}: + input_t_info = infer_input_info(graph) + else: + tensors_exist(graph, list(input_t_info.keys())) + graph = set_input_shapes(graph, input_t_info) + input_t_names = list(input_t_info.keys()) + graph_def = grappler.optimize(graph, input_t_names, output_t_names) + graph_def = grappler.convert_consts_to_var(graph_def) + graph = get_graph_from(graph_def) + + feed_dict = {} + for name, shape in input_t_info.items(): + tensor = get_tensor(graph, name) + zeros = np.zeros(shape) + feed_dict[tensor] = zeros + + cwd = os.getcwd() + with graph.as_default(): + with tf.compat.v1.Session() as sess: + # Run initializers generated by preprocessing + if check_operation_exists(graph, "init_constvars"): + sess.run(graph.get_operation_by_name("init_constvars")) + sess.run(tf.compat.v1.global_variables_initializer()) + model_dir = os.path.realpath(os.path.dirname(model_fname)) + os.chdir(model_dir) + + # At this stage the graph still has constants embedded in it + # in the assign nodes for variables. We cannot execute the graph without + # these constants. We strip them away in a new graph def which is amenable + # to codegen but leave them in the graph. + optimized_graph_def = DumpTFMtData.strip_variable_init_constants( + graph_def, input_t_names, output_t_names + ) + + tf_graph_io.dump_graph_def_pb( + optimized_graph_def, "optimised_" + model_fname + ) + DumpTFMtData.save_graphdef(optimized_graph_def) + DumpTFMtData.save_sizeinfo(optimized_graph_def, sess, feed_dict) + print("Model compilation done.") + if save_weights: + weights_fname = ( + model_name + + "_input_weights_fixedpt_scale_" + + str(scaling_factor) + + ".inp" + ) + print( + "\nDumping model weights in ", + model_dir + "/" + weights_fname, + ".\nThese are to be used as input for party which owns the model\n", + ) + DumpTFMtData.save_weights( + optimized_graph_def, sess, feed_dict, weights_fname, scaling_factor + ) + os.chdir(cwd) + return + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", required=True, type=str, help="Path to the config file" + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + params = parse_config.get_params(args.config) + compile( + params["model_name"], + params["input_tensors"], + params["output_tensors"], + params["scale"], + params["save_weights"], + ) diff --git a/Athos/CompilerScripts/grappler.py b/Athos/CompilerScripts/grappler.py new file mode 100644 index 00000000..3edfe6be --- /dev/null +++ b/Athos/CompilerScripts/grappler.py @@ -0,0 +1,176 @@ +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import tensorflow as tf +import tensorflow.contrib.graph_editor as ge +from tensorflow.python.platform import gfile +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.grappler import cluster +from tensorflow.compat.v1 import GraphKeys +from tensorflow.core.protobuf.meta_graph_pb2 import SignatureDef +from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig + + +def get_graph_from(graph_def): + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name="") + return graph + + +def get_default_config(): + c = tf_optimizer.config_pb2.ConfigProto() + optimizer_opts = c.graph_options.rewrite_options + OFF = RewriterConfig.Toggle.Value("OFF") + optimizer_opts.layout_optimizer = OFF + optimizer_opts.implementation_selector = OFF + optimizer_opts.min_graph_nodes = -1 + optimizer_opts.meta_optimizer_iterations = 2 + optimizer_opts.memory_optimization = RewriterConfig.MemOptType.Value("NO_MEM_OPT") + return c + + +def get_only_prune_config(): + c = get_default_config() + optimizer_opts = c.graph_options.rewrite_options + OFF = RewriterConfig.Toggle.Value("OFF") + optimizer_opts.constant_folding = OFF + optimizer_opts.shape_optimization = OFF + optimizer_opts.remapping = OFF + optimizer_opts.arithmetic_optimization = OFF + optimizer_opts.dependency_optimization = OFF + optimizer_opts.loop_optimization = OFF + optimizer_opts.function_optimization = OFF + optimizer_opts.meta_optimizer_iterations = 1 + return c + + +def get_white_list(graph): + transp_perm_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Transpose" + ) + padding_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Pad" + ) + slice_begin_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Slice" + ) + slice_size_ops = set( + i.inputs[2].op.name for i in graph.get_operations() if i.type == "Slice" + ) + mean_axes_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Mean" + ) + split_dim_ops = set( + i.inputs[0].op.name for i in graph.get_operations() if i.type == "Split" + ) + concat_axes_ops = set( + i.inputs[2].op.name + for i in graph.get_operations() + if i.type == "ConcatV2" or i.type == "Concat" + ) + white_list = ( + transp_perm_ops + | padding_ops + | slice_begin_ops + | slice_size_ops + | mean_axes_ops + | split_dim_ops + | concat_axes_ops + ) + return list(white_list) + + +def optimize(g, inputs, outputs): + sd = SignatureDef() + for name in inputs: + input_t = g.get_operation_by_name(name).outputs[0] + sd.inputs[name].name = name + sd.inputs[name].dtype = input_t.dtype.as_datatype_enum + sd.inputs[name].tensor_shape.CopyFrom(input_t.shape.as_proto()) + for name in outputs: + output_t = g.get_operation_by_name(name).outputs[0] + sd.outputs[name].name = name + sd.outputs[name].dtype = output_t.dtype.as_datatype_enum + sd.outputs[name].tensor_shape.CopyFrom(output_t.shape.as_proto()) + + tf.compat.v1.enable_resource_variables() + cl = cluster.Cluster(disable_detailed_stats=True) + + # We have to run this twice to eliminate constants that are left after + # optimising away split/pad/transpose nodes. They are const parameters like + # axis, perm. They remain after 1 iter of optimization because we specify them + # in the whitelist + for i in range(2): + if i == 0: + graph = g + c = get_default_config() + else: + graph = get_graph_from(optimized_graph_def) + c = get_only_prune_config() + + white_list = get_white_list(graph) + for name in white_list: + graph.add_to_collection( + GraphKeys.TRAIN_OP, graph.get_operation_by_name(name) + ) + + meta_graph = tf.compat.v1.train.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph + ) + meta_graph.signature_def["not_used_key"].CopyFrom(sd) + + optimized_graph_def = tf_optimizer.OptimizeGraph( + config_proto=c, metagraph=meta_graph, cluster=cl + ) + # Don't create VarHandleOp, ReadVariableOp, VarIsInitializedOp + # Instead create VariableV2 ops in the future + tf.disable_resource_variables() + return optimized_graph_def + + +def delete_nodes(gd, ops): + nodes_to_delete = set(op.name for op in ops) + new_gd = tf.compat.v1.GraphDef() + nodes_to_keep = [] + for n in gd.node: + if not n.name in nodes_to_delete: + nodes_to_keep.append(n) + new_gd.node.extend(nodes_to_keep) + return new_gd + + +def convert_consts_to_var(graph_def): + graph = get_graph_from(graph_def) + all_const_ops = set(i.name for i in graph.get_operations() if i.type == "Const") + const_names_list = list(all_const_ops - set(get_white_list(graph))) + const_var_names_pairs = [] + ops_to_delete = [] + with graph.as_default(): + preexisting_vars = [ + tf.get_variable(i.name, i.outputs[0].shape) + for i in graph.get_operations() + if i.type == "VariableV2" or i.type == "Variable" + ] + + var_list = [] + for name in const_names_list: + tensor = graph.get_operation_by_name(name).outputs[0] + with tf.compat.v1.Session() as sess: + t_value = sess.run(tensor) + t_name = "{}_mpc_const_var".format(name) + var = tf.compat.v1.Variable(t_value, name=t_name) + var_read_op_name = var.to_proto().snapshot_name[:-2] + const_var_names_pairs.append((name, var_read_op_name)) + var_list.append(var) + + for const_name, var_read_name in const_var_names_pairs: + const_op = graph.get_operation_by_name(const_name) + var_op = graph.get_operation_by_name(var_read_name) + ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_op)) + ops_to_delete.append(const_op) + + tf.compat.v1.variables_initializer( + var_list + preexisting_vars, "init_constvars" + ) + return delete_nodes(graph.as_graph_def(), ops_to_delete) diff --git a/Athos/CompilerScripts/parse_config.py b/Athos/CompilerScripts/parse_config.py new file mode 100644 index 00000000..64edfd35 --- /dev/null +++ b/Athos/CompilerScripts/parse_config.py @@ -0,0 +1,181 @@ +import argparse +import os.path +import json +import sys + +""" +Sample config: +{ +// Mandatory + "model_name":"model.pb", + "output_tensors":[ + "output1", + "output2" + ], + "scale":10, + "bitlength":63, + "target":"SCI", // ABY/CPP/CPPRING/PORTHOS/SCI + "save_weights" : true +// Optional + "input_tensors":{ + "actual_input_1":"2,245,234,3", + "input2":"2,245,234,3" + }, + "modulo" + "backend" + "disable_all_hlil_opts" + "disable_relu_maxpool_opts" + "disable_garbage_collection" + "disable_trunc_opts" +} +""" + + +def get_config(config_path): + if not os.path.isfile(config_path): + sys.exit("Config file specified does not exist") + with open(config_path) as f: + try: + config = json.load(f) + except JSONDecodeError as e: + sys.exit( + "Error while parsing the config json:\n" + + e.msg + + " at line no. " + + str(e.lineno) + ) + return config + + +def get_str_param(config, p_name): + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == str, p_name + " is not a string" + return p + + +def get_opt_str_param(config, p_name): + p = config.get(p_name) + if p is None: + return p + assert type(p) == str, p_name + " is not a string" + return p + + +def get_bool_param(config, p_name): + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == bool, p_name + " is not a boolean" + return p + + +def get_opt_bool_param(config, p_name): + p = config.get(p_name) + if p is None: + return p + assert type(p) == bool, p_name + " is not a boolean" + return p + + +def get_int_param(config, p_name): + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == int, p_name + " is not an integer" + return p + + +def get_opt_int_param(config, p_name): + p = config.get(p_name) + if p is None: + return p + assert type(p) == int, p_name + " is not an integer" + return p + + +def get_str_list_param(config, p_name): + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == list, p_name + "is not a list of strings" + for i in p: + assert type(i) == str, p_name + "is not a list of strings" + return p + + +def get_opt_param(config, p_name): + p = config.get(p_name) + return p + + +def get_shape_list(shape_string): + shape = [] + if shape_string == "": + return shape + for i in shape_string.split(","): + assert i.isnumeric(), "Given input shape has non-integer values" + shape.append(int(i)) + return shape + + +def parse_input_tensors(config): + input_t_info = {} + p = config.get("input_tensors") + if p is None: + return input_t_info + assert type(p) == dict, "Input tensors should be a dict of name=>shape" + for name, shape_str in p.items(): + input_t_info[name] = get_shape_list(shape_str) + return input_t_info + + +def parse_config(config): + model_fname = get_str_param(config, "model_name") + if not model_fname.endswith(".pb"): + sys.exit( + model_fname + + " is not a tensorflow protobuf file. Please supply " + + "a valid tensorflow protobuf model (.pb extension)" + ) + if not os.path.isfile(model_fname): + sys.exit(model_fname + " file does not exist") + target = get_str_param(config, "target").upper() + output_tensors = get_str_list_param(config, "output_tensors") + input_t_info = parse_input_tensors(config) + + save_weights = get_opt_bool_param(config, "save_weights") + scale = get_opt_int_param(config, "scale") + bitlen = get_opt_int_param(config, "bitlength") + modulo = get_opt_int_param(config, "modulo") + backend = get_opt_str_param(config, "backend") + disable_hlil_opts = get_opt_bool_param(config, "disable_all_hlil_opts") + disable_rmo = get_opt_bool_param(config, "disable_relu_maxpool_opts") + disable_garbage_collection = get_opt_bool_param( + config, "disable_garbage_collection" + ) + disable_trunc = get_opt_bool_param(config, "disable_trunc_opts") + + params = { + "model_name": model_fname, + "input_tensors": input_t_info, + "output_tensors": output_tensors, + "scale": scale, + "bitlength": bitlen, + "target": target, + "save_weights": save_weights, + "modulo": modulo, + "backend": backend, + "disable_all_hlil_opts": disable_hlil_opts, + "disable_relu_maxpool_opts": disable_rmo, + "disable_garbage_collection": disable_garbage_collection, + "disable_trunc_opts": disable_trunc, + } + return params + + +def get_params(config_fname): + config = get_config(config_fname) + params = parse_config(config) + return params diff --git a/Athos/CompilerScripts/tf_graph_io.py b/Athos/CompilerScripts/tf_graph_io.py index a50d7500..c0bc0507 100644 --- a/Athos/CompilerScripts/tf_graph_io.py +++ b/Athos/CompilerScripts/tf_graph_io.py @@ -18,6 +18,18 @@ def dump_pb(graph, filename): graph_def = graph.as_graph_def() f.write(graph_def.SerializeToString()) +def dump_graph_def_pb(graph_def, filename): + with tf.io.gfile.GFile(filename, 'wb') as f: + f.write(graph_def.SerializeToString()) + +def dump_pb_without_vars(graph, output_names, filename): + with tf.io.gfile.GFile(filename, 'wb') as f: + with tf.Session(graph=graph) as sess: + sess.run(tf.global_variables_initializer()) + graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, + graph.as_graph_def(), output_names) + f.write(graph_def.SerializeToString()) + def save_model(graph, model_name): with graph.as_default(): with tf.Session() as sess: diff --git a/Athos/CompilerScripts/tf_graph_trans.py b/Athos/CompilerScripts/tf_graph_trans.py index 918006fe..5d53edd3 100644 --- a/Athos/CompilerScripts/tf_graph_trans.py +++ b/Athos/CompilerScripts/tf_graph_trans.py @@ -24,7 +24,6 @@ def remove_dead_nodes(graph, in_list, out_list): return opt_graph def convert_consts_to_var(graph, const_names_list): - const_var_names_pairs = [] ops_to_delete = [] with graph.as_default(): diff --git a/Athos/TFCompiler/DumpTFMtData.py b/Athos/TFCompiler/DumpTFMtData.py index 575a75f5..35982958 100644 --- a/Athos/TFCompiler/DumpTFMtData.py +++ b/Athos/TFCompiler/DumpTFMtData.py @@ -26,20 +26,46 @@ import tensorflow as tf from tensorflow.tools.graph_transforms import TransformGraph -def get_optimized_graph_def(output_tensor): - #First save the graph def - graph_def = tf.get_default_graph().as_graph_def() +def strip_variable_init_constants(graph_def, input_tensor_names, output_tensor_names): transforms = [ 'remove_nodes(op=Identity)', 'strip_unused_nodes', - 'fold_batch_norms', - 'fold_constants(ignore_errors=true)' ] - optimized_graph_def = TransformGraph(graph_def, [], [output_tensor.name], transforms) + optimized_graph_def = TransformGraph(graph_def, input_tensor_names, output_tensor_names, transforms) return optimized_graph_def +def save_graphdef(graph_def): + with open('./graphDef.mtdata', 'w') as f: + f.write(str(graph_def)) + +def save_sizeinfo(optimized_graph_def, sess, feed_dict): + # Save size information for tensors on which output depends + tensors_to_evaluate = [] + tensors_to_evaluate_names = [] + graph = sess.graph + for node in optimized_graph_def.node: + output_number = 0 + for cur_output in graph.get_operation_by_name(node.name).outputs: + tensors_to_evaluate.append(cur_output) + if output_number == 0: + tensor_name = node.name + else: + tensor_name = cur_output.name + tensors_to_evaluate_names.append(tensor_name) + output_number += 1 + tensors_evaluated = sess.run(tensors_to_evaluate, feed_dict) + tensors_shape = list(map(lambda x : x.shape, tensors_evaluated)) + + # Write size info in a file + with open('./sizeInfo.mtdata','w') as f: + for ii, curr in enumerate(tensors_to_evaluate_names): + curShape = tensors_shape[ii] + f.write(tensors_to_evaluate_names[ii] + ' ') + for dim in curShape: + f.write(str(dim)+' ') + f.write('\n') + def save_graph_metadata(output_tensor, sess, feed_dict): - #First save the graph def graph_def = sess.graph_def transforms = [ 'remove_nodes(op=Identity)', @@ -80,33 +106,23 @@ def save_graph_metadata(output_tensor, sess, feed_dict): return optimized_graph_def def updateWeightsForBN(optimized_graph_def, sess, feed_dict={}): - def findNodeInGraphDefWithName(graphDef, curName): - for curNode in graphDef.node: - if curNode.name == curName: - return curNode - return None - - print("Updating weights for BN...") - graph = sess.graph graphDef = optimized_graph_def for node in graphDef.node: - if (node.op == 'FusedBatchNorm' or node.op == 'FusedBatchNormV3'): - gamma = graph.get_operation_by_name(node.input[1]).outputs[0] - beta = graph.get_operation_by_name(node.input[2]).outputs[0] - mu = graph.get_operation_by_name(node.input[3]).outputs[0] - variance = graph.get_operation_by_name(node.input[4]).outputs[0] + if (node.op == 'FusedBatchNorm' or node.op == 'FusedBatchNormV3'): + gamma = graph.get_operation_by_name(node.input[1]).outputs[0] + beta = graph.get_operation_by_name(node.input[2]).outputs[0] + mu = graph.get_operation_by_name(node.input[3]).outputs[0] + variance = graph.get_operation_by_name(node.input[4]).outputs[0] - epsilon = node.attr['epsilon'].f - rsigma = tf.rsqrt(variance + epsilon) + epsilon = node.attr['epsilon'].f + rsigma = tf.rsqrt(variance + epsilon) - sess.run(tf.assign(gamma, gamma*rsigma)) - sess.run(tf.assign(beta, beta - gamma*mu)) - sess.run(tf.assign(mu, tf.zeros(tf.shape(mu)))) - sess.run(tf.assign(variance, tf.fill(tf.shape(variance), 1-epsilon))) - - print("BN weight updation done. Continuing...") + sess.run(tf.assign(gamma, gamma*rsigma), feed_dict) + sess.run(tf.assign(beta, beta - gamma*mu), feed_dict) + sess.run(tf.assign(mu, tf.zeros(tf.shape(mu))), feed_dict) + sess.run(tf.assign(variance, tf.fill(tf.shape(variance), 1-epsilon)), feed_dict) def dumpImageDataInt(imgData, filename, scalingFac, writeMode): print("Dumping image data...") @@ -150,3 +166,19 @@ def numpy_float_array_to_float_val_str(input_array): for val in numpy.nditer(input_array): chunk += str(val) + '\n' return chunk + +def save_weights(optimized_graph_def, sess, feed_dict, filename, scaling_factor): + graph = sess.graph + varNames = [ + node.name + for node in optimized_graph_def.node + if node.op in ["VariableV2", "Variable"] + ] + graph_vars = [graph.get_operation_by_name(i).outputs[0] for i in varNames] + updateWeightsForBN(optimized_graph_def, sess, feed_dict) + values = sess.run(graph_vars, feed_dict) + with open(filename, "w") as ff: + for val in values: + for xx in numpy.nditer(val, order="C"): + ff.write(str(int(xx * (1 << scaling_factor))) + " ") + ff.write("\n") \ No newline at end of file diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 915ef1e5..854cb80c 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -23,7 +23,8 @@ ''' import os, sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'SeeDot')) #Add SeeDot directory to path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'SeeDot')) #Add SeeDot directory to path import Graph, AST.AST as AST, _pickle as pickle, os from TFNodesAST import TFNodesAST @@ -50,7 +51,7 @@ def generateIRCode(graph, extraInfoDict): mtdAST = MtdAST() for curNode in graph.getAllNodesRef(): for curInp in curNode.getInputsRef(): - assert(curInp in dictNodeNameToOutVarStr) #Consequence of topological sorting of the TF graph + assert(curInp in dictNodeNameToOutVarStr), "input={} expected as input but not yet processed".format(curInp) #Consequence of topological sorting of the TF graph (assignedVarAST, curAsts) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict) for outputName, curAst in curAsts.items(): mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(), @@ -112,7 +113,6 @@ def prefixAllPlaceHolderNodes(graph): remNodes.append(curNode) graph.setNodesList(placeHolderNodes + remNodes) - # List of Optimisations # 1. Split squared difference into (a-b)*(a-b) def simplifyGraph(graph): @@ -138,15 +138,9 @@ def simplifyGraph(graph): newNodes.append(curNode) graph.setNodesList(newNodes) -def main(): +def process_tf_graph(filename): sys.setrecursionlimit(10000) - # First read the graph file - if (len(sys.argv) < 2): - print("TF python file unspecified.", file=sys.stderr) - exit(1) - - filename = sys.argv[1] folderName = os.path.dirname(filename) graphFileName = os.path.join(folderName, 'graphDef.mtdata') graph = Graph.Graph() @@ -191,4 +185,9 @@ def main(): pickle.dump(program, f) if __name__ == "__main__": - main() + if (len(sys.argv) < 2): + print("TF python file unspecified.", file=sys.stderr) + exit(1) + + filename = sys.argv[1] + process_tf_graph(filename) \ No newline at end of file From a349497b8d823553951b33a2e8f0a57c72936d35 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 02:22:30 +0530 Subject: [PATCH 23/72] Pass key directory as argument for Porthos. Previously programs linked with porthos, expected the keys to be in the current_directory/files folder. Now users can explicitly pass the directory while invoking the program. The keys still have to be named keyA, keyAB, keyB and keyD but they can be located in arbitary folders. So party 0 will invoke the program as: ./program 0 files/addresses_file path/to/keys/dir --- EzPC/EzPC/codegenporthos.ml | 49 ++++++++--------- Porthos/files/{ => keys}/keyA | 0 Porthos/files/{ => keys}/keyAB | 0 Porthos/files/{ => keys}/keyB | 0 Porthos/files/{ => keys}/keyD | 0 Porthos/party0.sh | 2 +- Porthos/party1.sh | 2 +- Porthos/party2.sh | 2 +- .../example_neural_nets/mainDenseNet121.cpp | 53 ++++++++++--------- .../src/example_neural_nets/mainResNet50.cpp | 53 ++++++++++--------- .../example_neural_nets/mainSqNetImgNet.cpp | 53 ++++++++++--------- Porthos/src/secondary.cpp | 8 +-- 12 files changed, 114 insertions(+), 108 deletions(-) rename Porthos/files/{ => keys}/keyA (100%) rename Porthos/files/{ => keys}/keyAB (100%) rename Porthos/files/{ => keys}/keyB (100%) rename Porthos/files/{ => keys}/keyD (100%) diff --git a/EzPC/EzPC/codegenporthos.ml b/EzPC/EzPC/codegenporthos.ml index ac7c3d57..87119cb5 100644 --- a/EzPC/EzPC/codegenporthos.ml +++ b/EzPC/EzPC/codegenporthos.ml @@ -566,6 +566,7 @@ let porthos_prelude_string :string = #include \"ezpc.h\"\n\ \n\ extern int partyNum;\n\ +extern string key_directory;\n\ vector toFreeMemoryLaterArr;\n\ int NUM_OF_PARTIES;\n\ \n\ @@ -608,37 +609,37 @@ string whichNetwork = \"Your Network\";\n\ show_porthos_mode();\n\ string indep_key_location, common_key_location;\n\ if(partyNum == PARTY_A){\n\ - indep_key_location = \"files/keyA\";\n\ - common_key_location = \"files/keyAB\";\n\ + indep_key_location = key_directory + \"/keyA\";\n\ + common_key_location = key_directory + \"/keyAB\";\n\ }\n\ else if(partyNum == PARTY_B){\n\ - indep_key_location = \"files/keyB\";\n\ - common_key_location = \"files/keyAB\";\n\ + indep_key_location = key_directory + \"/keyB\";\n\ + common_key_location = key_directory + \"/keyAB\";\n\ }\n\ else{\n\ - indep_key_location = \"files/keyB\";\n\ - common_key_location = \"files/keyAB\";\n\ + indep_key_location = key_directory + \"/keyB\";\n\ + common_key_location = key_directory + \"/keyAB\";\n\ }\n\ aes_indep = new AESObject(indep_key_location);\n\ aes_common = new AESObject(common_key_location);\n\ -aes_a_1 = new AESObject(\"files/keyD\");\n\ -aes_a_2 = new AESObject(\"files/keyD\");\n\ -aes_b_1 = new AESObject(\"files/keyD\");\n\ -aes_b_2 = new AESObject(\"files/keyD\");\n\ -aes_c_1 = new AESObject(\"files/keyD\");\n\ -aes_share_conv_bit_shares_p0_p2 = new AESObject(\"files/keyD\");\n\ -aes_share_conv_bit_shares_p1_p2 = new AESObject(\"files/keyD\");\n\ -aes_share_conv_shares_mod_odd_p0_p2 = new AESObject(\"files/keyD\");\n\ -aes_share_conv_shares_mod_odd_p1_p2 = new AESObject(\"files/keyD\");\n\ -aes_comp_msb_shares_lsb_p0_p2 = new AESObject(\"files/keyD\");\n\ -aes_comp_msb_shares_lsb_p1_p2 = new AESObject(\"files/keyD\");\n\ -aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject(\"files/keyD\");\n\ -aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject(\"files/keyD\");\n\ -aes_conv_opti_a_1 = new AESObject(\"files/keyD\");\n\ -aes_conv_opti_a_2 = new AESObject(\"files/keyD\");\n\ -aes_conv_opti_b_1 = new AESObject(\"files/keyD\");\n\ -aes_conv_opti_b_2 = new AESObject(\"files/keyD\");\n\ -aes_conv_opti_c_1 = new AESObject(\"files/keyD\");\n\ +aes_a_1 = new AESObject(key_directory + \"/keyD\");\n\ +aes_a_2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_b_1 = new AESObject(key_directory + \"/keyD\");\n\ +aes_b_2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_c_1 = new AESObject(key_directory + \"/keyD\");\n\ +aes_share_conv_bit_shares_p0_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_share_conv_bit_shares_p1_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_share_conv_shares_mod_odd_p0_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_share_conv_shares_mod_odd_p1_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_comp_msb_shares_lsb_p0_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_comp_msb_shares_lsb_p1_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_conv_opti_a_1 = new AESObject(key_directory + \"/keyD\");\n\ +aes_conv_opti_a_2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_conv_opti_b_1 = new AESObject(key_directory + \"/keyD\");\n\ +aes_conv_opti_b_2 = new AESObject(key_directory + \"/keyD\");\n\ +aes_conv_opti_c_1 = new AESObject(key_directory + \"/keyD\");\n\ aes_parallel = new ParallelAESObject(common_key_location);\n\ \n\ if (MPC)\n\ diff --git a/Porthos/files/keyA b/Porthos/files/keys/keyA similarity index 100% rename from Porthos/files/keyA rename to Porthos/files/keys/keyA diff --git a/Porthos/files/keyAB b/Porthos/files/keys/keyAB similarity index 100% rename from Porthos/files/keyAB rename to Porthos/files/keys/keyAB diff --git a/Porthos/files/keyB b/Porthos/files/keys/keyB similarity index 100% rename from Porthos/files/keyB rename to Porthos/files/keys/keyB diff --git a/Porthos/files/keyD b/Porthos/files/keys/keyD similarity index 100% rename from Porthos/files/keyD rename to Porthos/files/keys/keyD diff --git a/Porthos/party0.sh b/Porthos/party0.sh index 4d193ffc..62303101 100755 --- a/Porthos/party0.sh +++ b/Porthos/party0.sh @@ -13,6 +13,6 @@ then exit 1 fi -./src/build/bin/$1 0 files/addresses +./src/build/bin/$1 0 files/addresses files/keys diff --git a/Porthos/party1.sh b/Porthos/party1.sh index cb857734..77eb82df 100755 --- a/Porthos/party1.sh +++ b/Porthos/party1.sh @@ -13,5 +13,5 @@ then exit 1 fi -./src/build/bin/$1 1 files/addresses +./src/build/bin/$1 1 files/addresses files/keys diff --git a/Porthos/party2.sh b/Porthos/party2.sh index d20bf226..ffc9019d 100755 --- a/Porthos/party2.sh +++ b/Porthos/party2.sh @@ -13,6 +13,6 @@ then exit 1 fi -./src/build/bin/$1 2 files/addresses +./src/build/bin/$1 2 files/addresses files/keys diff --git a/Porthos/src/example_neural_nets/mainDenseNet121.cpp b/Porthos/src/example_neural_nets/mainDenseNet121.cpp index 0e72fec5..28cfeb43 100644 --- a/Porthos/src/example_neural_nets/mainDenseNet121.cpp +++ b/Porthos/src/example_neural_nets/mainDenseNet121.cpp @@ -53,6 +53,7 @@ return os; #include "ezpc.h" extern int partyNum; +extern string key_directory; vector toFreeMemoryLaterArr; int NUM_OF_PARTIES; @@ -1822,9 +1823,9 @@ arr[i1][i2][i3][i4] = reshapedArr[linIdx]; } ClearMemSecret1(size, reshapedArr); } - -extern int instanceID; + +extern int instanceID; int main(int argc, char** argv) { parseInputs(argc, argv); @@ -1832,37 +1833,37 @@ string whichNetwork = "Your Network"; show_porthos_mode(); string indep_key_location, common_key_location; if(partyNum == PARTY_A){ -indep_key_location = "files/keyA"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyA"; +common_key_location = key_directory + "/keyAB"; } else if(partyNum == PARTY_B){ -indep_key_location = "files/keyB"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyB"; +common_key_location = key_directory + "/keyAB"; } else{ -indep_key_location = "files/keyB"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyB"; +common_key_location = key_directory + "/keyAB"; } aes_indep = new AESObject(indep_key_location); aes_common = new AESObject(common_key_location); -aes_a_1 = new AESObject("files/keyD"); -aes_a_2 = new AESObject("files/keyD"); -aes_b_1 = new AESObject("files/keyD"); -aes_b_2 = new AESObject("files/keyD"); -aes_c_1 = new AESObject("files/keyD"); -aes_share_conv_bit_shares_p0_p2 = new AESObject("files/keyD"); -aes_share_conv_bit_shares_p1_p2 = new AESObject("files/keyD"); -aes_share_conv_shares_mod_odd_p0_p2 = new AESObject("files/keyD"); -aes_share_conv_shares_mod_odd_p1_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_lsb_p0_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_lsb_p1_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject("files/keyD"); -aes_conv_opti_a_1 = new AESObject("files/keyD"); -aes_conv_opti_a_2 = new AESObject("files/keyD"); -aes_conv_opti_b_1 = new AESObject("files/keyD"); -aes_conv_opti_b_2 = new AESObject("files/keyD"); -aes_conv_opti_c_1 = new AESObject("files/keyD"); +aes_a_1 = new AESObject(key_directory + "/keyD"); +aes_a_2 = new AESObject(key_directory + "/keyD"); +aes_b_1 = new AESObject(key_directory + "/keyD"); +aes_b_2 = new AESObject(key_directory + "/keyD"); +aes_c_1 = new AESObject(key_directory + "/keyD"); +aes_share_conv_bit_shares_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_bit_shares_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_shares_mod_odd_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_shares_mod_odd_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_lsb_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_lsb_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_a_1 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_a_2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_b_1 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_b_2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_c_1 = new AESObject(key_directory + "/keyD"); aes_parallel = new ParallelAESObject(common_key_location); if (MPC) diff --git a/Porthos/src/example_neural_nets/mainResNet50.cpp b/Porthos/src/example_neural_nets/mainResNet50.cpp index fa2b2a11..7d482eeb 100644 --- a/Porthos/src/example_neural_nets/mainResNet50.cpp +++ b/Porthos/src/example_neural_nets/mainResNet50.cpp @@ -53,6 +53,7 @@ return os; #include "ezpc.h" extern int partyNum; +extern string key_directory; vector toFreeMemoryLaterArr; int NUM_OF_PARTIES; @@ -1822,9 +1823,9 @@ arr[i1][i2][i3][i4] = reshapedArr[linIdx]; } ClearMemSecret1(size, reshapedArr); } - -extern int instanceID; + +extern int instanceID; int main(int argc, char** argv) { parseInputs(argc, argv); @@ -1832,37 +1833,37 @@ string whichNetwork = "Your Network"; show_porthos_mode(); string indep_key_location, common_key_location; if(partyNum == PARTY_A){ -indep_key_location = "files/keyA"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyA"; +common_key_location = key_directory + "/keyAB"; } else if(partyNum == PARTY_B){ -indep_key_location = "files/keyB"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyB"; +common_key_location = key_directory + "/keyAB"; } else{ -indep_key_location = "files/keyB"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyB"; +common_key_location = key_directory + "/keyAB"; } aes_indep = new AESObject(indep_key_location); aes_common = new AESObject(common_key_location); -aes_a_1 = new AESObject("files/keyD"); -aes_a_2 = new AESObject("files/keyD"); -aes_b_1 = new AESObject("files/keyD"); -aes_b_2 = new AESObject("files/keyD"); -aes_c_1 = new AESObject("files/keyD"); -aes_share_conv_bit_shares_p0_p2 = new AESObject("files/keyD"); -aes_share_conv_bit_shares_p1_p2 = new AESObject("files/keyD"); -aes_share_conv_shares_mod_odd_p0_p2 = new AESObject("files/keyD"); -aes_share_conv_shares_mod_odd_p1_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_lsb_p0_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_lsb_p1_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject("files/keyD"); -aes_conv_opti_a_1 = new AESObject("files/keyD"); -aes_conv_opti_a_2 = new AESObject("files/keyD"); -aes_conv_opti_b_1 = new AESObject("files/keyD"); -aes_conv_opti_b_2 = new AESObject("files/keyD"); -aes_conv_opti_c_1 = new AESObject("files/keyD"); +aes_a_1 = new AESObject(key_directory + "/keyD"); +aes_a_2 = new AESObject(key_directory + "/keyD"); +aes_b_1 = new AESObject(key_directory + "/keyD"); +aes_b_2 = new AESObject(key_directory + "/keyD"); +aes_c_1 = new AESObject(key_directory + "/keyD"); +aes_share_conv_bit_shares_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_bit_shares_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_shares_mod_odd_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_shares_mod_odd_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_lsb_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_lsb_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_a_1 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_a_2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_b_1 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_b_2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_c_1 = new AESObject(key_directory + "/keyD"); aes_parallel = new ParallelAESObject(common_key_location); if (MPC) diff --git a/Porthos/src/example_neural_nets/mainSqNetImgNet.cpp b/Porthos/src/example_neural_nets/mainSqNetImgNet.cpp index aa5ddf18..e2c8d19e 100644 --- a/Porthos/src/example_neural_nets/mainSqNetImgNet.cpp +++ b/Porthos/src/example_neural_nets/mainSqNetImgNet.cpp @@ -53,6 +53,7 @@ return os; #include "ezpc.h" extern int partyNum; +extern string key_directory; vector toFreeMemoryLaterArr; int NUM_OF_PARTIES; @@ -1822,9 +1823,9 @@ arr[i1][i2][i3][i4] = reshapedArr[linIdx]; } ClearMemSecret1(size, reshapedArr); } - -extern int instanceID; + +extern int instanceID; int main(int argc, char** argv) { parseInputs(argc, argv); @@ -1832,37 +1833,37 @@ string whichNetwork = "Your Network"; show_porthos_mode(); string indep_key_location, common_key_location; if(partyNum == PARTY_A){ -indep_key_location = "files/keyA"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyA"; +common_key_location = key_directory + "/keyAB"; } else if(partyNum == PARTY_B){ -indep_key_location = "files/keyB"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyB"; +common_key_location = key_directory + "/keyAB"; } else{ -indep_key_location = "files/keyB"; -common_key_location = "files/keyAB"; +indep_key_location = key_directory + "/keyB"; +common_key_location = key_directory + "/keyAB"; } aes_indep = new AESObject(indep_key_location); aes_common = new AESObject(common_key_location); -aes_a_1 = new AESObject("files/keyD"); -aes_a_2 = new AESObject("files/keyD"); -aes_b_1 = new AESObject("files/keyD"); -aes_b_2 = new AESObject("files/keyD"); -aes_c_1 = new AESObject("files/keyD"); -aes_share_conv_bit_shares_p0_p2 = new AESObject("files/keyD"); -aes_share_conv_bit_shares_p1_p2 = new AESObject("files/keyD"); -aes_share_conv_shares_mod_odd_p0_p2 = new AESObject("files/keyD"); -aes_share_conv_shares_mod_odd_p1_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_lsb_p0_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_lsb_p1_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject("files/keyD"); -aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject("files/keyD"); -aes_conv_opti_a_1 = new AESObject("files/keyD"); -aes_conv_opti_a_2 = new AESObject("files/keyD"); -aes_conv_opti_b_1 = new AESObject("files/keyD"); -aes_conv_opti_b_2 = new AESObject("files/keyD"); -aes_conv_opti_c_1 = new AESObject("files/keyD"); +aes_a_1 = new AESObject(key_directory + "/keyD"); +aes_a_2 = new AESObject(key_directory + "/keyD"); +aes_b_1 = new AESObject(key_directory + "/keyD"); +aes_b_2 = new AESObject(key_directory + "/keyD"); +aes_c_1 = new AESObject(key_directory + "/keyD"); +aes_share_conv_bit_shares_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_bit_shares_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_shares_mod_odd_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_share_conv_shares_mod_odd_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_lsb_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_lsb_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_bit_vec_p0_p2 = new AESObject(key_directory + "/keyD"); +aes_comp_msb_shares_bit_vec_p1_p2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_a_1 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_a_2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_b_1 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_b_2 = new AESObject(key_directory + "/keyD"); +aes_conv_opti_c_1 = new AESObject(key_directory + "/keyD"); aes_parallel = new ParallelAESObject(common_key_location); if (MPC) diff --git a/Porthos/src/secondary.cpp b/Porthos/src/secondary.cpp index d2d5762f..a087d474 100644 --- a/Porthos/src/secondary.cpp +++ b/Porthos/src/secondary.cpp @@ -28,6 +28,7 @@ using namespace std; //this player number int partyNum; +string key_directory; //aes_key of the party char *party_aes_key; @@ -48,18 +49,19 @@ void parseInputs(int argc, char* argv[]) { assert((sizeof(double) == sizeof(porthosSecretType)) && "sizeof(double) != sizeof(porthosSecretType)"); - if(argc == 3){ + if(argc == 4){ instanceID = 0; } - else if(argc == 4){ + else if(argc == 5){ instanceID = atoi(argv[3]); } else{ porthos_throw_error(PARSE_ERROR); - cout<<"Porthos expects either 3 or 4 CLI arguments!"< Date: Mon, 11 Jan 2021 02:30:15 +0530 Subject: [PATCH 24/72] Always checkout Eigen3 and SEAL while building SCI Locally check them out in extern/ directory so that programs can be linked against SCI without manually putting them in the networks directory and modifying the CMake file. See CompileTFGraph.py for seeing how to build and link programs directly. --- SCI/src/LinearHE/CMakeLists.txt | 55 +++++++++++++++------------------ 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/SCI/src/LinearHE/CMakeLists.txt b/SCI/src/LinearHE/CMakeLists.txt index cf1f91fb..775b7c1a 100644 --- a/SCI/src/LinearHE/CMakeLists.txt +++ b/SCI/src/LinearHE/CMakeLists.txt @@ -1,15 +1,13 @@ find_package(OpenMP REQUIRED) # set(CMAKE_FIND_DEBUG_MODE 1) -find_package(SEAL 3.3.2 EXACT QUIET) -if (NOT SEAL_FOUND) - message(STATUS "SEAL 3.3.2 was not found: clone and install SEAL locally") - if (NOT EXISTS "${PROJECT_SOURCE_DIR}/extern/SEAL/native/src/CMakeLists.txt") - find_package(Git REQUIRED) - message(STATUS "initialize Git submodule: extern/SEAL") - execute_process(COMMAND git submodule update --init --recursive extern/SEAL - WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}") - endif () +message(STATUS "SEAL 3.3.2 was not found: clone and install SEAL locally") +if (NOT EXISTS "${PROJECT_SOURCE_DIR}/extern/SEAL/native/src/CMakeLists.txt") + find_package(Git REQUIRED) + message(STATUS "initialize Git submodule: extern/SEAL") + execute_process(COMMAND git submodule update --init --recursive extern/SEAL + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}") +endif () if(APPLE) execute_process(COMMAND ${CMAKE_COMMAND} -DCMAKE_INSTALL_PREFIX=${PROJECT_SOURCE_DIR}/build . -DCMAKE_C_COMPILER=${MAC_GCC} -DCMAKE_CXX_COMPILER=${MAC_GPP} @@ -18,29 +16,26 @@ else () execute_process(COMMAND ${CMAKE_COMMAND} -DCMAKE_INSTALL_PREFIX=${PROJECT_SOURCE_DIR}/build . WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/SEAL/native/src") endif() - execute_process(COMMAND ${CMAKE_COMMAND} --build . --target install - WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/SEAL/native/src") - find_package(SEAL 3.3.2 EXACT REQUIRED PATHS "${PROJECT_SOURCE_DIR}/build/") -endif () +execute_process(COMMAND ${CMAKE_COMMAND} --build . --target install + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/SEAL/native/src") +find_package(SEAL 3.3.2 EXACT REQUIRED PATHS "${PROJECT_SOURCE_DIR}/build/") -find_package(Eigen3 3.3 NO_MODULE QUIET) -if (NOT Eigen3_FOUND) - message(STATUS "Eigen 3.3 was not found: clone and install Eigen3 locally") - if (NOT EXISTS "${PROJECT_SOURCE_DIR}/extern/eigen/CMakeLists.txt") - find_package(Git REQUIRED) - message(STATUS "initialize Git submodule: extern/eigen") - execute_process(COMMAND git submodule update --init --recursive extern/eigen - WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}") - endif () - execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory build - WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/eigen/") - execute_process(COMMAND ${CMAKE_COMMAND} -DCMAKE_INSTALL_PREFIX=${PROJECT_SOURCE_DIR}/build .. - WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/eigen/build") - execute_process(COMMAND ${CMAKE_COMMAND} --build .. --target install - WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/eigen/build") - message(STATUS "${PROJECT_SOURCE_DIR}") - find_package(Eigen3 3.3 REQUIRED NO_MODULE PATHS "${PROJECT_SOURCE_DIR}/build/") +message(STATUS "Eigen 3.3 was not found: clone and install Eigen3 locally") +if (NOT EXISTS "${PROJECT_SOURCE_DIR}/extern/eigen/CMakeLists.txt") + find_package(Git REQUIRED) + message(STATUS "initialize Git submodule: extern/eigen") + execute_process(COMMAND git submodule update --init --recursive extern/eigen + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}") endif () +execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory build + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/eigen/") +execute_process(COMMAND ${CMAKE_COMMAND} -DCMAKE_INSTALL_PREFIX=${PROJECT_SOURCE_DIR}/build .. + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/eigen/build") +execute_process(COMMAND ${CMAKE_COMMAND} --build .. --target install + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/extern/eigen/build") +message(STATUS "${PROJECT_SOURCE_DIR}") +find_package(Eigen3 3.3 REQUIRED NO_MODULE PATHS "${PROJECT_SOURCE_DIR}/build/") + add_library(SCI-LinearHE conv-field.cpp From 4346e966990d18f591d6c12e4a8a23dd4efa782a Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 02:36:07 +0530 Subject: [PATCH 25/72] Add support to directly link code with SCI, Porthos Now user does not have to manually copy generated cpp files into network directories of SCI and Porthos and edit CMakeFiles. Once SCI and Porthos libraries are built the CompileTFGraph.py can directly compile the generated cpp files and link them against SCI/Porthos. --- Athos/CompileTFGraph.py | 65 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 8 deletions(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 71d893ec..8ee57c84 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -56,10 +56,7 @@ def parse_args(): args = parser.parse_args() return args - -if __name__ == "__main__": - args = parse_args() - params = parse_config.get_params(args.config) +def generate_code(params): # Mandatory model_name = params["model_name"] input_tensor_info = params["input_tensors"] @@ -79,11 +76,12 @@ def parse_args(): "CPPRING", ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" + cwd = os.getcwd() athos_dir = os.path.dirname(os.path.abspath(__file__)) model_abs_path = os.path.abspath(model_name) model_abs_dir = os.path.dirname(model_abs_path) # Generate graphdef and sizeInfo metadata - compile_tf.compile( + weights_path = compile_tf.compile( model_name, input_tensor_info, output_tensors, scale, save_weights ) @@ -91,7 +89,7 @@ def parse_args(): Athos.process_tf_graph(model_abs_path) # Compile to ezpc - model_base_name = model_name[:-3] + model_base_name = os.path.basename(model_abs_path)[:-3] ezpc_file_name = "{mname}_{bl}_{target}.ezpc".format( mname=model_base_name, bl=bitlength, target=target.lower() ) @@ -160,11 +158,12 @@ def parse_args(): modulo = params["modulo"] backend = "OT" if params["backend"] is None else params["backend"] ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") + # Copy generated code to the ezpc directory os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) os.chdir(ezpc_dir) ezpc_args = "" - ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac".format( - bl=lib_bitlength, target=target + ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( + bl=bitlength, target=target ) output_name = ezpc_file_name[:-5] + "0.cpp" if modulo is not None: @@ -174,9 +173,59 @@ def parse_args(): output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) if target in ["PORTHOS"]: ezpc_args += "--sf {} ".format(scale) + os.system( "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args ) os.system( "cp {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) ) + output_file = os.path.join(model_abs_dir, output_name) + + if target == "PORTHOS2PC": + program_name = model_base_name + "_" + target + "_" + backend + ".out" + else: + program_name = model_base_name + "_" + target + ".out" + program_path = os.path.join(model_abs_dir, program_name) + os.chdir(model_abs_dir) + if target in [ "CPP", "CPPRING"]: + os.system( + "g++ -O3 -w {file} -o {output}".format(file=output_file, output=program_path) + ) + elif target == "PORTHOS": + porthos_src = os.path.join(athos_dir, "..", "Porthos", "src") + porthos_lib = os.path.join(porthos_src, "build", "lib") + if os.path.exists(porthos_lib): + os.system( + """g++ -O3 -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ + -mrdseed -fpermissive -fpic -std=c++17 -L {porthos_lib} -I {porthos_headers} {file} \ + -lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system \ + -o {output}""".format(porthos_lib=porthos_lib, porthos_headers=porthos_src, + file=output_file, output=program_path) + ) + else: + print("Not compiling generated code. Please follow the readme and build Porthos.") + elif target == "PORTHOS2PC": + sci = os.path.join(athos_dir, "..", "SCI") + sci_src = os.path.join(sci, "src") + sci_lib = os.path.join(sci, "build", "lib") + eigen_path = os.path.join(sci, "extern", "eigen") + seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib") + if os.path.exists(sci_lib): + os.system( + """g++ -O3 -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ + -faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {file} \ + -L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto \ + -o {output}""".format(eigen=eigen_path, sci_src=sci_src, + file=output_file,sci_lib=sci_lib,seal=seal_lib_path, output=program_path) + ) + else: + print("Not compiling generated code. Please follow the readme and build SCI.") + + os.chdir(cwd) + return (program_path, weights_path) + +if __name__ == "__main__": + args = parse_args() + params = parse_config.get_params(args.config) + generate_code(params) \ No newline at end of file From 2aa79b128f9892c8965691b47c6c5ad983d65e95 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 02:42:41 +0530 Subject: [PATCH 26/72] Minor improvements to compiler scripts, handle const to var better for some ops --- Athos/CompilerScripts/compile_tf.py | 6 ++++-- Athos/CompilerScripts/grappler.py | 13 +++++++++++++ Athos/CompilerScripts/parse_config.py | 2 +- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py index fa9accc8..06ac70ac 100644 --- a/Athos/CompilerScripts/compile_tf.py +++ b/Athos/CompilerScripts/compile_tf.py @@ -120,11 +120,12 @@ def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weig ) tf_graph_io.dump_graph_def_pb( - optimized_graph_def, "optimised_" + model_fname + optimized_graph_def, "optimised_" + model_name + ".pb" ) DumpTFMtData.save_graphdef(optimized_graph_def) DumpTFMtData.save_sizeinfo(optimized_graph_def, sess, feed_dict) print("Model compilation done.") + weights_path = "" if save_weights: weights_fname = ( model_name @@ -140,8 +141,9 @@ def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weig DumpTFMtData.save_weights( optimized_graph_def, sess, feed_dict, weights_fname, scaling_factor ) + weights_path = os.path.join(model_dir, weights_fname) os.chdir(cwd) - return + return weights_path def parse_args(): diff --git a/Athos/CompilerScripts/grappler.py b/Athos/CompilerScripts/grappler.py index 3edfe6be..5055c8f0 100644 --- a/Athos/CompilerScripts/grappler.py +++ b/Athos/CompilerScripts/grappler.py @@ -61,6 +61,9 @@ def get_white_list(graph): mean_axes_ops = set( i.inputs[1].op.name for i in graph.get_operations() if i.type == "Mean" ) + sum_axes_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Sum" + ) split_dim_ops = set( i.inputs[0].op.name for i in graph.get_operations() if i.type == "Split" ) @@ -69,14 +72,24 @@ def get_white_list(graph): for i in graph.get_operations() if i.type == "ConcatV2" or i.type == "Concat" ) + argmax_axes_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "ArgMax" + ) + divisor_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type in ["FloorDiv", "RealDiv"] + ) + white_list = ( transp_perm_ops | padding_ops | slice_begin_ops | slice_size_ops | mean_axes_ops + | sum_axes_ops | split_dim_ops | concat_axes_ops + | argmax_axes_ops + | divisor_ops ) return list(white_list) diff --git a/Athos/CompilerScripts/parse_config.py b/Athos/CompilerScripts/parse_config.py index 64edfd35..53300a1f 100644 --- a/Athos/CompilerScripts/parse_config.py +++ b/Athos/CompilerScripts/parse_config.py @@ -115,7 +115,7 @@ def get_shape_list(shape_string): if shape_string == "": return shape for i in shape_string.split(","): - assert i.isnumeric(), "Given input shape has non-integer values" + assert i.isnumeric(), "Given input shape has non-integer value : {}".format(i) shape.append(int(i)) return shape From 818b8c2ca4dc7c78f7831cb1e530e363f9b66097 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 02:43:49 +0530 Subject: [PATCH 27/72] Handle TransformGraph for cases where constant outputs where converted to vars. If the graph has a constant output, it will get converted to a variable. While dumping graph_defs TransformGraph needs to be able to find that output. So we teach it to find the newly created variable. --- Athos/TFCompiler/DumpTFMtData.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/Athos/TFCompiler/DumpTFMtData.py b/Athos/TFCompiler/DumpTFMtData.py index 35982958..40b151e0 100644 --- a/Athos/TFCompiler/DumpTFMtData.py +++ b/Athos/TFCompiler/DumpTFMtData.py @@ -31,7 +31,21 @@ def strip_variable_init_constants(graph_def, input_tensor_names, output_tensor_n 'remove_nodes(op=Identity)', 'strip_unused_nodes', ] - optimized_graph_def = TransformGraph(graph_def, input_tensor_names, output_tensor_names, transforms) + # Sanity check if output/input nodes were constant and replaced with variables. + all_node_names = set([i.name for i in graph_def.node]) + def get_true_names(tensor_names, all_nodes): + real_names = [] + for i in tensor_names: + if i not in all_nodes: + var_name = i + "_mpc_const_var" + if var_name in all_nodes: + real_names.append(var_name) + else: + real_names.append(i) + return real_names + real_input_names = get_true_names(input_tensor_names, all_node_names) + real_output_names = get_true_names(output_tensor_names, all_node_names) + optimized_graph_def = TransformGraph(graph_def, real_input_names, real_output_names, transforms) return optimized_graph_def def save_graphdef(graph_def): @@ -179,6 +193,8 @@ def save_weights(optimized_graph_def, sess, feed_dict, filename, scaling_factor) values = sess.run(graph_vars, feed_dict) with open(filename, "w") as ff: for val in values: + if val.shape == (0,): #Empty array, nothing to dump. + continue for xx in numpy.nditer(val, order="C"): ff.write(str(int(xx * (1 << scaling_factor))) + " ") ff.write("\n") \ No newline at end of file From 5420ee02fc28ec80d04fb3f81da274adc40a4d31 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 02:47:02 +0530 Subject: [PATCH 28/72] Typo fix in seedot --- Athos/SeeDot/IR/IR.py | 2 +- Athos/SeeDot/IR/IRBuilderCSF.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Athos/SeeDot/IR/IR.py b/Athos/SeeDot/IR/IR.py index b08db5f3..50362d0e 100644 --- a/Athos/SeeDot/IR/IR.py +++ b/Athos/SeeDot/IR/IR.py @@ -115,7 +115,7 @@ def subst(self, from_idf:str, to_e:Expr): class IntBop(IntExpr): def __init__(self, e1:IntExpr, op:Op.Op, e2:IntExpr): - assert(op in Op.Op.op_list('+ - * / << >> & | ^')) + assert(op in Op.Op.op_list('+ - * / << >> & | ^ ==')) self.e1 = e1 self.op = op self.e2 = e2 diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index de0cf2f5..a8279ae6 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -942,7 +942,7 @@ def visitBopConvTranspose(self, node:AST.BOp, args=None): assert(AST.Operators.findConvOutputImgSize(d_prime_tilde, pad_d_tr_total, FD, stride_d_tr) == D) returnExpr = self.getTempVar() - comment = IR.Comment(expr_1.idf + ' #T ' + expr2.idf + ', convDim = ' + str(convDim)) + comment = IR.Comment(expr_1.idf + ' #T ' + expr_2.idf + ', convDim = ' + str(convDim)) funcCallArgsDict = OrderedDict() funcCallArgsDict[IR.Int(N, 32)] = "N" if convDim==3: From 6999cc02f727db40122334461fad8fe2c6c9b55e Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 02:47:42 +0530 Subject: [PATCH 29/72] Fix CreateTensor signature for 64 bit case --- Athos/TFEzPCLibrary/Library64_common.ezpc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Athos/TFEzPCLibrary/Library64_common.ezpc b/Athos/TFEzPCLibrary/Library64_common.ezpc index 395ccc89..d276e2ad 100644 --- a/Athos/TFEzPCLibrary/Library64_common.ezpc +++ b/Athos/TFEzPCLibrary/Library64_common.ezpc @@ -136,13 +136,13 @@ def void MatAdd5(int32_pl a1, int32_pl a2, int32_pl a3, int32_pl a4, int32_pl a5 }; } (**************************) -def void CreateTensor1(int32_pl s1, int32_pl val, int32_pl[s1] arr){ +def void CreateTensor1(int32_pl s1, int64_pl val, int64_pl[s1] arr){ for i1=[0:s1]{ arr[i1] = val; }; } -def void CreateTensor2(int32_pl s1, int32_pl s2, int32_pl val, int32_pl[s1][s2] arr){ +def void CreateTensor2(int32_pl s1, int32_pl s2, int64_pl val, int64_pl[s1][s2] arr){ for i1=[0:s1]{ for i2=[0:s2]{ arr[i1][i2] = val; @@ -150,7 +150,7 @@ def void CreateTensor2(int32_pl s1, int32_pl s2, int32_pl val, int32_pl[s1][s2] }; } -def void CreateTensor3(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl val, int32_pl[s1][s2][s3] arr){ +def void CreateTensor3(int32_pl s1, int32_pl s2, int32_pl s3, int64_pl val, int64_pl[s1][s2][s3] arr){ for i1=[0:s1]{ for i2=[0:s2]{ for i3=[0:s3]{ @@ -160,7 +160,7 @@ def void CreateTensor3(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl val, int3 }; } -def void CreateTensor4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl val, int32_pl[s1][s2][s3][s4] arr){ +def void CreateTensor4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_pl val, int64_pl[s1][s2][s3][s4] arr){ for i1=[0:s1]{ for i2=[0:s2]{ for i3=[0:s3]{ @@ -172,7 +172,7 @@ def void CreateTensor4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32 }; } -def void CreateTensor5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl val, int32_pl[s1][s2][s3][s4][s5] arr){ +def void CreateTensor5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_pl val, int64_pl[s1][s2][s3][s4][s5] arr){ for i1=[0:s1]{ for i2=[0:s2]{ for i3=[0:s3]{ From af13174c5340d28ab59fb81abd717c6cf7473f08 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 02:55:48 +0530 Subject: [PATCH 30/72] Add unittests and related scripts To run tests, navigate to Athos/tests directory (or provide path to pytest): 1. Run all tests with target as CPP. Backend can be CPP,3PC,2PC_HE,2PC_OT pytest -rs . --backend="CPP" 2. Run a specific test. pytest -rs . -k "test_arith_binop" --backend="CPP" 3. Run and generate a coverage report pytest --cov --cov-report html --cov-config=pytest_coverage_tf.config . Install pytest and pytest-cov to run the above commands. --- Athos/tests/conftest.py | 65 ++++ Athos/tests/pytest.ini | 2 + Athos/tests/pytest_coverage_tf.config | 12 + Athos/tests/tf/unittests/test_arith_binops.py | 175 +++++++++++ Athos/tests/tf/unittests/test_batchnorm.py | 48 +++ Athos/tests/tf/unittests/test_convolution.py | 90 ++++++ Athos/tests/tf/unittests/test_non_linear.py | 58 ++++ .../tf/unittests/test_shape_manipulation.py | 297 ++++++++++++++++++ Athos/tests/tf/unittests/test_unaryops.py | 194 ++++++++++++ Athos/tests/utils.py | 199 ++++++++++++ 10 files changed, 1140 insertions(+) create mode 100644 Athos/tests/conftest.py create mode 100644 Athos/tests/pytest.ini create mode 100644 Athos/tests/pytest_coverage_tf.config create mode 100644 Athos/tests/tf/unittests/test_arith_binops.py create mode 100644 Athos/tests/tf/unittests/test_batchnorm.py create mode 100644 Athos/tests/tf/unittests/test_convolution.py create mode 100644 Athos/tests/tf/unittests/test_non_linear.py create mode 100644 Athos/tests/tf/unittests/test_shape_manipulation.py create mode 100644 Athos/tests/tf/unittests/test_unaryops.py create mode 100644 Athos/tests/utils.py diff --git a/Athos/tests/conftest.py b/Athos/tests/conftest.py new file mode 100644 index 00000000..38bf6157 --- /dev/null +++ b/Athos/tests/conftest.py @@ -0,0 +1,65 @@ +import pytest +import tempfile +import shutil +import os + + +def pytest_addoption(parser): + parser.addoption( + "--backend", + action="store", + default="CPP", + help="backend : CPP | 2PC_HE | 2PC_OT | 3PC", + ) + + +@pytest.fixture(scope="session") +def backend(request): + opt = request.config.getoption("--backend") + if opt not in ["CPP", "3PC", "2PC_HE", "2PC_OT"]: + opt = "CPP" + return opt + + +@pytest.fixture(scope="session", autouse=True) +def test_env(): + config = {} + test_dir = "cryptflow_tests" + path = os.path.join(tempfile.gettempdir(), test_dir) + if os.path.exists(path): + shutil.rmtree(path, ignore_errors=True) + os.mkdir(path) + config["test_dir"] = path + return config + + +def make_dir(path): + if os.path.exists(path): + shutil.rmtree(path, ignore_errors=True) + else: + os.mkdir(path) + return + + +# Hook to check if test failed +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + # execute all other hooks to obtain the report object + outcome = yield + rep = outcome.get_result() + # set a report attribute for each phase of a call, which can + # be "setup", "call", "teardown" + setattr(item, "rep_" + rep.when, rep) + + +@pytest.fixture +def test_dir(request, test_env): + test_name = request.node.name[len("test_") :] + main_test_dir = test_env["test_dir"] + test_dir = os.path.join(main_test_dir, "athos_test_" + test_name) + make_dir(test_dir) + yield test_dir + # Remove dir only if test passed + if not request.node.rep_call.failed: + shutil.rmtree(test_dir, ignore_errors=True) + return diff --git a/Athos/tests/pytest.ini b/Athos/tests/pytest.ini new file mode 100644 index 00000000..1ceab942 --- /dev/null +++ b/Athos/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -p no:warnings diff --git a/Athos/tests/pytest_coverage_tf.config b/Athos/tests/pytest_coverage_tf.config new file mode 100644 index 00000000..e896bccf --- /dev/null +++ b/Athos/tests/pytest_coverage_tf.config @@ -0,0 +1,12 @@ +[run] +branch = True +source = + ../SeeDot + ../TFCompiler + +[report] +exclude_lines = + if __name__ == .__main__.: + +[html] +directory = coverage_html_report diff --git a/Athos/tests/tf/unittests/test_arith_binops.py b/Athos/tests/tf/unittests/test_arith_binops.py new file mode 100644 index 00000000..7ad5bd41 --- /dev/null +++ b/Athos/tests/tf/unittests/test_arith_binops.py @@ -0,0 +1,175 @@ +import tensorflow as tf +import numpy as np + +import pytest + +import sys +import os + +# Athos DIR +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from tests.utils import Config, Compiler, assert_almost_equal + +@pytest.mark.parametrize( + "a_shape,b_shape,dtype", + [ + ((4, 4), (4, 4), np.single), # Normal + ((2, 2), (1,), np.single), # Broadcasting + ((3, 1, 2, 1), (2, 1, 4), np.single), # Broadcasting + ((2, 2), (), np.single), # Constant + ], +) +@pytest.mark.parametrize( + "tfOp", [tf.math.add, tf.math.subtract, tf.math.multiply, tf.raw_ops.AddV2] +) +def test_arith_binop(test_dir, backend, tfOp, a_shape, b_shape, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + b_inp = dtype(np.random.randn(*b_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + b = tf.constant(b_inp, name="b") + output = tfOp(x=a, y=b, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + +@pytest.mark.parametrize( + "a_shape, b_shape, data_format, dtype", + [ + ([4, 1, 4], [4], None, np.single), # Normal + ([4, 1, 4], [4], 'N..C', np.single), # Same as above + pytest.param([4, 4, 1], [4], 'NC..', np.single, marks=pytest.mark.skip(reason="[bias_add] NC.. not supported")), # Normal + ], +) +def test_bias_add(test_dir, backend, a_shape, b_shape, data_format, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + b_inp = dtype(np.random.randn(*b_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + b = tf.constant(b_inp, name="b") + output = tf.nn.bias_add(value=a, bias=b, data_format=data_format, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.parametrize( + "tfOp, a_val, divisor", + [ + pytest.param(tf.divide, [7, -7], 5, marks=pytest.mark.skip(reason="[divide] Support for parsing DOUBLES")), # [1, -2] + (tf.divide, [7.0, -7.0], 5.0), # [1.4, -1.4] + pytest.param(tf.truediv, [7, -7], 5, marks=pytest.mark.skip(reason="[divide] Support for parsing DOUBLES")), # [1.4, -1.4] + (tf.truediv, [7.0], 5.0), # [1.4] + (tf.divide, 7.0, 5.0), # 1.4 + pytest.param(tf.floordiv, [7, -7], 5, marks=pytest.mark.skip(reason="[divide] Add support for converting div by constant into a mul")), # [1, -2] + pytest.param(tf.floordiv, [7.0, -7.0], 5.0, marks=pytest.mark.skip(reason="[divide] Add support for converting div by constant into a mul")), # [1.0, -2.0] + pytest.param(tf.truncatediv, -7, 5, marks=pytest.mark.skip(reason="[divide] Truncated div not supported")), # -1 + ], +) +def test_div(test_dir, backend, tfOp, a_val, divisor, dtype): + graph = tf.Graph() + a_inp = np.array(a_val) + with graph.as_default(): + b = tf.constant(divisor, name="b") + a = tf.compat.v1.placeholder(tf.as_dtype(b.dtype), shape=a_inp.shape, name="a") + output = tfOp(a, b, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a_shape, b_shape, transpose_a, transpose_b, bisModel", + [ + ([3, 2], [2, 3], False, False, True), + pytest.param( + [3, 2], + [2, 3], + False, + False, + False, + marks=pytest.mark.skip( + reason="[matmul] expect atleast one param to belong to model" + ), + ), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_matmul( + test_dir, backend, a_shape, b_shape, transpose_a, transpose_b, bisModel, dtype +): + if backend == "2PC_HE": + pytest.skip( + "Assertion error in 2PC_HE FCField::matrix_multiplication Assertion `num_cols == 1' failed." + ) + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + b_inp = dtype(np.random.randn(*b_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + if bisModel: + b = tf.constant(b_inp, name="b") + else: + b = tf.compat.v1.placeholder( + tf.as_dtype(dtype), shape=b_inp.shape, name="b" + ) + output = tf.matmul(a, b, transpose_a, transpose_b, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + feed_dict = {a: a_inp} + if not bisModel: + feed_dict[b] = b_inp + expected_output = sess.run(output, feed_dict=feed_dict) + config = Config(backend).add_input(a).add_output(output) + if not bisModel: + config.add_input(b) + config.config["scale"] = 12 + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a, b", + [ + ([1.2, 1.3], [1.2, 1.3]), + ([1.2, 1.3], [1.2, 1.2]), + ([1.2, 1.3], [1.2]), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.skip(reason="[equal] Not able to cast boolean to int ezpc") +def test_equal(test_dir, backend, a, b, dtype): + graph = tf.Graph() + a_inp = dtype(np.array(a)) + b_inp = dtype(np.array(b)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + b = tf.constant(b_inp, name="b") + output = tf.math.equal(a, b, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return diff --git a/Athos/tests/tf/unittests/test_batchnorm.py b/Athos/tests/tf/unittests/test_batchnorm.py new file mode 100644 index 00000000..5b6f393d --- /dev/null +++ b/Athos/tests/tf/unittests/test_batchnorm.py @@ -0,0 +1,48 @@ +import tensorflow as tf +import numpy as np + +import pytest + +import sys +import os + +# Athos DIR +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from tests.utils import Config, Compiler, assert_almost_equal + + +@pytest.mark.parametrize( + "a_shape, scale, offset, mean, variance", + [([1, 2, 2, 1], [1.5], [2.3], [0.5], [0.2]), + #([1], 1.5, 2.3, 0.5, 0.2), ([], 1.5, 2.3, 0.5, 0.2) + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.parametrize( + "tfOp", [tf.raw_ops.FusedBatchNorm] +) +@pytest.mark.skip(reason="[batch_norm] Test not complete") +def test_fused_batch_norm( + test_dir, backend, tfOp, a_shape, scale, offset, mean, variance, dtype +): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tfOp( + x=a, + scale=scale, + offset=offset, + mean=mean, + variance=variance, + is_training=False, + name="output", + ) + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + assert expected_output is not None + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return \ No newline at end of file diff --git a/Athos/tests/tf/unittests/test_convolution.py b/Athos/tests/tf/unittests/test_convolution.py new file mode 100644 index 00000000..1cb24479 --- /dev/null +++ b/Athos/tests/tf/unittests/test_convolution.py @@ -0,0 +1,90 @@ +import tensorflow as tf +import numpy as np + +import pytest + +import sys +import os + +# Athos DIR +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from tests.utils import Config, Compiler, assert_almost_equal + + +@pytest.mark.parametrize( + "tfOp, a_shape, kernel_shape, strides, padding", + [ + (tf.nn.conv2d, [1, 5, 5, 1], [2, 2, 1, 2], [1, 1, 1, 1], "SAME"), + (tf.nn.conv2d, [1, 5, 5, 1], [2, 2, 1, 2], [1, 1, 1, 1], "VALID"), + (tf.nn.conv3d, [1, 5, 5, 5, 1], [2, 2, 2, 1, 2], [1, 1, 1, 1, 1], "SAME"), + (tf.nn.conv3d, [1, 5, 5, 5, 1], [2, 2, 2, 1, 2], [1, 1, 1, 1, 1], "VALID"), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_conv(test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + kernel_inp = dtype(np.random.randn(*kernel_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + filters = tf.constant(kernel_inp, name="filter") + output = tfOp(a, filters, strides, padding, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "tfOp, a_shape, kernel_shape, output_shape, strides, padding", + [ + ( + tf.nn.conv3d_transpose, + [1, 4, 4, 4, 2], + [2, 2, 2, 1, 2], + [1, 5, 5, 5, 1], + [1, 1, 1, 1, 1], + "VALID", + ), + pytest.param( + tf.nn.conv3d_transpose, + [1, 5, 5, 5, 2], + [2, 2, 2, 1, 2], + [1, 5, 5, 5, 1], + [1, 1, 1, 1, 1], + "SAME", + marks=pytest.mark.skip(reason="[conv3d_transpose] SAME padding bug"), + ), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_conv_transpose( + test_dir, + backend, + tfOp, + a_shape, + kernel_shape, + output_shape, + strides, + padding, + dtype, +): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + kernel_inp = dtype(np.random.randn(*kernel_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + filters = tf.constant(kernel_inp, name="filter") + output = tfOp(a, filters, output_shape, strides, padding, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return diff --git a/Athos/tests/tf/unittests/test_non_linear.py b/Athos/tests/tf/unittests/test_non_linear.py new file mode 100644 index 00000000..356fe6e5 --- /dev/null +++ b/Athos/tests/tf/unittests/test_non_linear.py @@ -0,0 +1,58 @@ +import tensorflow as tf +import numpy as np + +import pytest + +import sys +import os + +# Athos DIR +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from tests.utils import Config, Compiler, assert_almost_equal + + +@pytest.mark.skip(reason="[non-linear] Haven't made non-linear functionalities public") +@pytest.mark.parametrize("a_shape", [(4, 4), (1,), ()]) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.parametrize( + "tfOp", + [ + tf.math.sqrt, + tf.math.rsqrt, + tf.math.sigmoid, + tf.math.tanh, + tf.nn.relu, + ], +) +def test_non_linear(test_dir, backend, tfOp, a_shape, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tfOp(a, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + assert expected_output is not None + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + +@pytest.mark.skip(reason="[softmax] Haven't made non-linear functionalities public") +@pytest.mark.parametrize("a_shape, axis", [((2, 3), 1), ((1,), 0)]) +@pytest.mark.parametrize("dtype", [np.single]) +def test_softmax(test_dir, backend, a_shape, axis, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.nn.softmax(a, axis=axis, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + assert expected_output is not None + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return diff --git a/Athos/tests/tf/unittests/test_shape_manipulation.py b/Athos/tests/tf/unittests/test_shape_manipulation.py new file mode 100644 index 00000000..bba9e3aa --- /dev/null +++ b/Athos/tests/tf/unittests/test_shape_manipulation.py @@ -0,0 +1,297 @@ +import tensorflow as tf +import numpy as np + +import pytest + +import sys +import os + +# Athos DIR +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from tests.utils import Config, Compiler, assert_almost_equal + + +@pytest.mark.parametrize( + "a_shape, out_shape", + [ + ([2, 3], [6]), + ([6], [2, 3]), + ([2, 3], [3, 2]), + ([2, 3], [-1]), # Flatten 1-D + pytest.param( + [1], [], marks=pytest.mark.skip(reason="[reshape] dumping weights error") + ), # convert to scalar + ([3, 2, 3], [2, -1]), # infer -1 as 9 + ([3, 2, 3], [-1, 9]), # infer -1 as 2 + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_reshape(test_dir, backend, a_shape, out_shape, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.reshape(a, out_shape, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + assert expected_output is not None + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a_shape, perm", + [([2, 3], [1, 0]), ([2, 4, 3], [0, 2, 1])], # normal transpose, with perm +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_transpose(test_dir, backend, a_shape, perm, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.transpose(a, perm, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a_shape, num_or_size_splits, axis", + [ + ([2, 10], 5, 1), + pytest.param( + [5, 7], + [1, 4, 2], + 1, + marks=pytest.mark.skip( + reason="[split] don't support split into specific sizes (SplitV)" + ), + ), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_split(test_dir, backend, a_shape, num_or_size_splits, axis, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.split(a, num_or_size_splits, axis, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + if type(output) == list: + tf_output = output[-1] + tf_expected_output = expected_output[-1] + else: + tf_output = output + tf_expected_output = expected_output + config = Config(backend).add_input(a).add_output(tf_output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal( + tf_output=tf_expected_output, mpc_tensor=mpc_output, precision=2 + ) + return + + +# Squeeze +# TODO: also add a squeeze dim example. +@pytest.mark.parametrize( + "a_shape, axis", + [ + pytest.param( + [1, 2, 1, 3, 1, 1], + None, + marks=pytest.mark.skip(reason="[squeeze] Parametric squeeze not supported"), + ), + pytest.param( + [1, 2, 1, 3, 1, 1], + [2, 4], + marks=pytest.mark.skip(reason="[squeeze] Parametric squeeze not supported"), + ), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_squeeze(test_dir, backend, a_shape, axis, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.squeeze(a, axis=axis, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a_shape, begin, size", + [ + ([3, 2, 3], [1, 0, 0], [1, 1, 3]), + ([3, 2, 3], [1, 0, 0], [1, 2, 3]), + ([3, 2, 3], [1, 0, 0], [2, 1, 3]), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_slice(test_dir, backend, a_shape, begin, size, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.slice(a, begin, size, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a_shape, b_shape, axis", + [ + ([2, 3], [3, 3], 0), + ([2, 3, 2, 1], [2, 6, 2, 1], 1), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_concat(test_dir, backend, a_shape, b_shape, axis, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + b_inp = dtype(np.random.randn(*b_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + b = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=b_inp.shape, name="b") + output = tf.concat([a, b], axis, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp, b: b_inp}) + + config = Config(backend).add_input(a).add_input(b).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp, b_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +# ExpandDims +@pytest.mark.parametrize( + "a_shape, axis", + [ + pytest.param( + [3, 2, 3], 1, marks=pytest.mark.skip(reason="[expand_dims] not supported") + ), + pytest.param( + [2, 5], 0, marks=pytest.mark.skip(reason="[expand_dims] not supported") + ), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_expand_dims(test_dir, backend, a_shape, axis, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.expand_dims(a, axis, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +# Pad +@pytest.mark.parametrize( + "a_shape, paddings, mode, constant_values", + [ + ([1, 2, 2, 1], [[1, 1], [1, 2], [1, 1], [1, 3]], "CONSTANT", 0), + pytest.param( + [1, 2, 2, 1], + [[1, 1], [1, 2], [1, 1], [1, 3]], + "REFLECT", + 0, + marks=pytest.mark.skip(reason="[pad] REFLECT not supported"), + ), + pytest.param( + [1, 2, 2, 1], + [[1, 1], [1, 2], [1, 1], [1, 3]], + "SYMMETRIC", + 0, + marks=pytest.mark.skip(reason="[pad] SYMMETRIC not supported"), + ), + pytest.param( + [2, 3], + [ + [1, 1], + [2, 2], + ], + "CONSTANT", + 0, + marks=pytest.mark.skip(reason="[pad] Generic pad not supported"), + ), + pytest.param( + [1, 2, 2, 1], + [[1, 1], [1, 2], [1, 1], [1, 3]], + "CONSTANT", + 1.2, + marks=pytest.mark.skip(reason="[pad] non-zero padding not supported"), + ), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_pad(test_dir, backend, a_shape, paddings, mode, constant_values, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + pad = tf.constant(paddings, name="paddings") + output = tf.pad( + a, pad, mode=mode, constant_values=constant_values, name="output" + ) + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +# Tile +@pytest.mark.parametrize( + "a_shape, multiples", [([2, 3], [1, 2]), ([2, 3], [2, 1]), ([2, 3], [2, 2])] +) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.skip(reason="[tile] Not supported") +def test_tile(test_dir, backend, a_shape, multiples, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + mults = tf.constant(multiples, name="multiples") + output = tf.tile(a, mults, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py new file mode 100644 index 00000000..519eb539 --- /dev/null +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -0,0 +1,194 @@ +import tensorflow as tf +import numpy as np + +import pytest + +import sys +import os + +# Athos DIR +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from tests.utils import Config, Compiler, assert_almost_equal + + +@pytest.mark.parametrize("a_shape", [[2, 2], []]) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.parametrize( + "tfOp", + [ + tf.math.square, + tf.math.negative, + pytest.param( + tf.math.floor, + marks=pytest.mark.skip(reason="[floor] Floor1 not implemented"), + ), + tf.shape, + tf.identity, + pytest.param( + tf.zeros_like, marks=pytest.mark.skip(reason="[zeros_like] EzPC issue for inp=[2,2]") + ), + ], +) +def test_uop(test_dir, backend, tfOp, a_shape, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tfOp(a, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a_shape, axis, keepdims", + [ + ([3, 2], None, False), + ([3, 2], [0, 1], False), + ([3, 2], 0, False), + ([3, 2], 1, False), + ([3, 2], 0, True), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.parametrize("tfOp", [tf.math.reduce_mean, tf.reduce_sum]) +@pytest.mark.skip(reason="[reduce] Reduce mean assert shape failure") +def test_reduce(test_dir, backend, tfOp, a_shape, axis, keepdims, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tfOp(a, axis=axis, keepdims=keepdims, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize( + "a_shape, axis", + [ + ([3, 2], None), + ([3, 2], 0), + ([3, 2], 1), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.skip(reason="[argmax] Generic argmax not implemented") +def test_argmax(test_dir, backend, a_shape, axis, dtype): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tf.math.argmax(a, axis=axis, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +# NHWC is the default format +@pytest.mark.parametrize( + "a_shape, ksize, strides, padding, data_format", + [ + ([1, 5, 5, 1], [1, 2, 2, 1], [1, 2, 2, 1], "VALID", "NHWC"), + pytest.param( + [1, 5, 5, 1], + [1, 2, 2, 1], + [1, 2, 2, 1], + "SAME", + "NHWC", + marks=pytest.mark.skip(reason="[max/avg_pool] Pooling SAME pad bug"), + ), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +@pytest.mark.parametrize("tfOp", [tf.nn.max_pool, tf.nn.avg_pool]) +def test_pool( + test_dir, backend, tfOp, a_shape, ksize, strides, padding, data_format, dtype +): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + output = tfOp( + a, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format, + name="output", + ) + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +# x = tf.constant([1.8, 2.2], dtype=tf.float32) +# tf.dtypes.cast(x, tf.int32) +# Currently cast acts as an identity operation. +@pytest.mark.parametrize("a_shape", [[2, 2]]) +@pytest.mark.parametrize( + "from_dtype, to_dtype", + [ + (np.single, np.single), + ( + np.double, + np.single, + ), + pytest.param( + np.single, + np.int32, + marks=pytest.mark.skip(reason="[cast] Only support identity cast"), + ), + ], +) +def test_cast(test_dir, backend, a_shape, from_dtype, to_dtype): + graph = tf.Graph() + a_inp = from_dtype(np.random.randn(*a_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder( + tf.as_dtype(from_dtype), shape=a_inp.shape, name="a" + ) + output = tf.cast(a, to_dtype, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + +@pytest.mark.parametrize("a_shape, value", [([2, 2], 9.2), ([], 9.2), ([2, 2], 1)]) +def test_fill(test_dir, backend, a_shape, value): + graph = tf.Graph() + with graph.as_default(): + output = tf.fill(a_shape, value) + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output) + + config = Config(backend).add_output(output) + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return \ No newline at end of file diff --git a/Athos/tests/utils.py b/Athos/tests/utils.py new file mode 100644 index 00000000..c678079b --- /dev/null +++ b/Athos/tests/utils.py @@ -0,0 +1,199 @@ +import tempfile +import sys +import os +import shutil +import re + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import CompilerScripts.parse_config as parse_config +import CompileTFGraph + +import numpy as np +import subprocess +import threading + + +class Config: + def __init__(self, mode): + self.config = { + "model_name": "model.pb", + "scale": 23, + "bitlength": 64, + "save_weights": True, + } + if mode == "CPP": + self.config["target"] = "CPP" + elif mode == "3PC": + self.config["target"] = "PORTHOS" + elif mode == "2PC_OT": + self.config["target"] = "PORTHOS2PC" + self.config["bitlength"] = 41 + self.config["backend"] = "OT" + + elif mode == "2PC_HE": + self.config["target"] = "PORTHOS2PC" + self.config["bitlength"] = 41 + self.config["backend"] = "HE" + else: + assert False, "Mode has to be one of CPP/3PC/2PC_OT/2PC_HE" + + def add_input(self, tensor_op): + input_name = tensor_op.op.name + shape = tensor_op.shape.as_list() + shape_string = ",".join(map(str, shape)) + inputs = self.config.get("input_tensors") + if inputs == None: + self.config["input_tensors"] = {input_name: shape_string} + else: + self.config["input_tensors"][input_name] = shape_string + return self + + def add_output(self, tensor_op): + output_name = tensor_op.op.name + outputs = self.config.get("output_tensors") + if outputs == None: + self.config["output_tensors"] = [output_name] + else: + self.config["output_tensors"].append(output_name) + return self + + +def get_params(config): + return parse_config.parse_config(config) + + +def make_dir(path): + if os.path.exists(path): + shutil.rmtree(path, ignore_errors=True) + else: + os.mkdir(path) + return + + +def save_graph(graph_def, config, test_dir): + fname = config["model_name"] + fpath = os.path.join(test_dir, fname) + with open(fpath, "wb") as f: + f.write(graph_def.SerializeToString()) + print("\n\nfile name: ", f.name, "\n\n\n") + config["model_name"] = fpath + return + + +def convert_raw_output_to_np(filename, bitlength, scale): + matcher = re.compile(r"[-]?[0-9]+") + scaled_array = [] + with open(filename, "r") as f: + for line in f: + match = matcher.fullmatch(line.rstrip()) + if match: + unsigned_number = int(match.group(0)) + number = ( + unsigned_number + if (unsigned_number < 2 ** (bitlength - 1)) + else unsigned_number - 2 ** bitlength + ) + scaled_array.append(float(number) / (2 ** scale)) + return np.array(scaled_array) + + +class Program: + def __init__(self, program_path, model_weight_path, params, test_dir): + self.program_path = program_path + self.model_weight_path = model_weight_path + self.scale = params["scale"] + self.bitlength = params["bitlength"] + self.target = params["target"] + self.test_dir = test_dir + + def run(self, inputs): + # scale input and dump to file + inputs_scaled = os.path.join( + self.test_dir, "input_fixedpt_scale_" + str(self.scale) + ".inp" + ) + with open(inputs_scaled, "w") as ff: + for i in inputs: + for xx in np.nditer(i, order="C"): + ff.write(str(int(xx * (1 << self.scale))) + " ") + ff.write("\n") + raw_output = os.path.join(self.test_dir, "raw_output") + if self.target == "CPP": + os.system( + "cat {inputs} {weights} | {program} > {output}".format( + program=self.program_path, + inputs=inputs_scaled, + weights=self.model_weight_path, + output=raw_output, + ) + ) + elif self.target == "PORTHOS": + util_dir = os.path.dirname(os.path.abspath(__file__)) + porthos_dir = os.path.join(util_dir, "..", "..", "Porthos") + ip_addr = os.path.join(porthos_dir, "files", "addresses") + keys_dir = os.path.join(porthos_dir, "files", "keys") + client_cmd = ( + "{program} 0 {ip_addr_file} {keys_dir} < {input} > {output}".format( + program=self.program_path, + ip_addr_file=ip_addr, + input=inputs_scaled, + output=raw_output, + keys_dir=keys_dir, + ) + ) + server_cmd = "{program} 1 {ip_addr_file} {keys_dir} < {input}".format( + program=self.program_path, + ip_addr_file=ip_addr, + input=self.model_weight_path, + keys_dir=keys_dir, + ) + party2_cmd = "{program} 2 {ip_addr_file} {keys_dir}".format( + program=self.program_path, ip_addr_file=ip_addr, keys_dir=keys_dir + ) + commands = [client_cmd, server_cmd, party2_cmd] + procs = [subprocess.Popen(i, shell=True) for i in commands] + for p in procs: + p.wait() + elif self.target == "PORTHOS2PC": + util_dir = os.path.dirname(os.path.abspath(__file__)) + sci_dir = os.path.join(util_dir, "..", "..", "SCI") + port = 1234 + client_cmd = "{program} r=2 p={port} < {input} > {output}".format( + program=self.program_path, + port=port, + input=inputs_scaled, + output=raw_output, + ) + server_cmd = "{program} r=1 p={port} < {input} > /dev/null".format( + program=self.program_path, + port=port, + input=self.model_weight_path, + output=raw_output, + ) + commands = [client_cmd, server_cmd] + procs = [subprocess.Popen(i, shell=True) for i in commands] + for p in procs: + p.wait() + return convert_raw_output_to_np(raw_output, self.bitlength, self.scale) + + +class Compiler: + def __init__(self, graph, config, test_dir): + self.graph_def = graph.as_graph_def() + self.config = config.config + self.test_dir = test_dir + + def compile_and_run(self, inputs): + save_graph(self.graph_def, self.config, self.test_dir) + params = get_params(self.config) + print(params) + (output_program, model_weight_file) = CompileTFGraph.generate_code(params) + prog = Program(output_program, model_weight_file, params, self.test_dir) + output = prog.run(inputs) + return output + + +def assert_almost_equal(tf_output, mpc_tensor, precision): + if tf_output.shape == (0,): + return + np.testing.assert_almost_equal(tf_output.flatten(), mpc_tensor, decimal=precision) + return From 055da16cdb4231dc75f4c1cfa26057b30b444a46 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 11 Jan 2021 03:12:12 +0530 Subject: [PATCH 31/72] Update disclaimer --- Athos/.gitignore | 3 ++- Athos/CompileTFGraph.py | 23 ++++++++++++++++++ .../comparison_scripts/compare_output.py | 23 ++++++++++++++++++ .../comparison_scripts/compare_output.sh | 21 ++++++++++++++++ .../comparison_scripts/convert_scale.py | 23 ++++++++++++++++++ .../comparison_scripts/convert_to_signed.py | 23 ++++++++++++++++++ .../comparison_scripts/convert_to_signed.sh | 20 ++++++++++++++++ Athos/CompilerScripts/compile_tf.py | 23 ++++++++++++++++++ Athos/CompilerScripts/compile_tf_graph.py | 23 ++++++++++++++++++ Athos/CompilerScripts/create_tf_input.py | 23 ++++++++++++++++++ Athos/CompilerScripts/get_pred_tf_graph.py | 23 ++++++++++++++++++ Athos/CompilerScripts/grappler.py | 23 ++++++++++++++++++ Athos/CompilerScripts/parse_config.py | 23 ++++++++++++++++++ .../preprocess_frozen_tf_graph.py | 23 ++++++++++++++++++ Athos/CompilerScripts/tf_graph_io.py | 23 ++++++++++++++++++ Athos/CompilerScripts/tf_graph_trans.py | 23 ++++++++++++++++++ Athos/tests/conftest.py | 23 ++++++++++++++++++ Athos/tests/tf/unittests/test_arith_binops.py | 24 +++++++++++++++++++ Athos/tests/tf/unittests/test_batchnorm.py | 23 ++++++++++++++++++ Athos/tests/tf/unittests/test_convolution.py | 23 ++++++++++++++++++ Athos/tests/tf/unittests/test_non_linear.py | 23 ++++++++++++++++++ .../tf/unittests/test_shape_manipulation.py | 23 ++++++++++++++++++ Athos/tests/tf/unittests/test_unaryops.py | 23 ++++++++++++++++++ Athos/tests/utils.py | 23 ++++++++++++++++++ 24 files changed, 527 insertions(+), 1 deletion(-) diff --git a/Athos/.gitignore b/Athos/.gitignore index 3ec64aa0..c36857d5 100644 --- a/Athos/.gitignore +++ b/Athos/.gitignore @@ -9,4 +9,5 @@ SeeDot/debug/ *__temp1.ezpc *__temp2.ezpc -__pycache__/ \ No newline at end of file +__pycache__/ +tests/debug diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 8ee57c84..f470557c 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import argparse from argparse import RawTextHelpFormatter diff --git a/Athos/CompilerScripts/comparison_scripts/compare_output.py b/Athos/CompilerScripts/comparison_scripts/compare_output.py index 469fe90f..260093e5 100644 --- a/Athos/CompilerScripts/comparison_scripts/compare_output.py +++ b/Athos/CompilerScripts/comparison_scripts/compare_output.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import numpy as np import sys diff --git a/Athos/CompilerScripts/comparison_scripts/compare_output.sh b/Athos/CompilerScripts/comparison_scripts/compare_output.sh index 9a745f14..dd1516a3 100644 --- a/Athos/CompilerScripts/comparison_scripts/compare_output.sh +++ b/Athos/CompilerScripts/comparison_scripts/compare_output.sh @@ -1,3 +1,24 @@ +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + # Usage: tf_output.float(floatingpt) party0_output(fixedpt) BITLEN SCALING_FACTOR PRECISION(upto how many points to compare?) # This first converts unsigned fixedpt to signed SCRIPT_DIR="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" diff --git a/Athos/CompilerScripts/comparison_scripts/convert_scale.py b/Athos/CompilerScripts/comparison_scripts/convert_scale.py index 2211d43b..228428c2 100644 --- a/Athos/CompilerScripts/comparison_scripts/convert_scale.py +++ b/Athos/CompilerScripts/comparison_scripts/convert_scale.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import sys if __name__ == '__main__': assert(len(sys.argv) == 4) diff --git a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py index cd3524f8..6b9808e5 100644 --- a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py +++ b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import sys if __name__ == "__main__": diff --git a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh index 867c2b92..94ac8c2b 100644 --- a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh +++ b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.sh @@ -1,3 +1,23 @@ +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + SCRIPT_DIR="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" inp1=$1 bitlen=$2 diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py index 06ac70ac..2e747f83 100644 --- a/Athos/CompilerScripts/compile_tf.py +++ b/Athos/CompilerScripts/compile_tf.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import argparse import os.path import json diff --git a/Athos/CompilerScripts/compile_tf_graph.py b/Athos/CompilerScripts/compile_tf_graph.py index c3f3727a..7e0d7187 100644 --- a/Athos/CompilerScripts/compile_tf_graph.py +++ b/Athos/CompilerScripts/compile_tf_graph.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import numpy as np import argparse diff --git a/Athos/CompilerScripts/create_tf_input.py b/Athos/CompilerScripts/create_tf_input.py index f4138a5b..064bb3ab 100644 --- a/Athos/CompilerScripts/create_tf_input.py +++ b/Athos/CompilerScripts/create_tf_input.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import argparse import os import sys diff --git a/Athos/CompilerScripts/get_pred_tf_graph.py b/Athos/CompilerScripts/get_pred_tf_graph.py index 337cd8d4..00c204ed 100644 --- a/Athos/CompilerScripts/get_pred_tf_graph.py +++ b/Athos/CompilerScripts/get_pred_tf_graph.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import numpy as np import argparse diff --git a/Athos/CompilerScripts/grappler.py b/Athos/CompilerScripts/grappler.py index 5055c8f0..65b2faa5 100644 --- a/Athos/CompilerScripts/grappler.py +++ b/Athos/CompilerScripts/grappler.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import os os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" diff --git a/Athos/CompilerScripts/parse_config.py b/Athos/CompilerScripts/parse_config.py index 53300a1f..a6189ed6 100644 --- a/Athos/CompilerScripts/parse_config.py +++ b/Athos/CompilerScripts/parse_config.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import argparse import os.path import json diff --git a/Athos/CompilerScripts/preprocess_frozen_tf_graph.py b/Athos/CompilerScripts/preprocess_frozen_tf_graph.py index d73705ac..c2d7cb73 100644 --- a/Athos/CompilerScripts/preprocess_frozen_tf_graph.py +++ b/Athos/CompilerScripts/preprocess_frozen_tf_graph.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' from tf_graph_io import * from tf_graph_trans import * import sys diff --git a/Athos/CompilerScripts/tf_graph_io.py b/Athos/CompilerScripts/tf_graph_io.py index c0bc0507..c99d97d6 100644 --- a/Athos/CompilerScripts/tf_graph_io.py +++ b/Athos/CompilerScripts/tf_graph_io.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf from tensorflow.python.platform import gfile diff --git a/Athos/CompilerScripts/tf_graph_trans.py b/Athos/CompilerScripts/tf_graph_trans.py index 5d53edd3..6cd90c2e 100644 --- a/Athos/CompilerScripts/tf_graph_trans.py +++ b/Athos/CompilerScripts/tf_graph_trans.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import tensorflow.contrib.graph_editor as ge from tensorflow.tools.graph_transforms import TransformGraph diff --git a/Athos/tests/conftest.py b/Athos/tests/conftest.py index 38bf6157..52f60356 100644 --- a/Athos/tests/conftest.py +++ b/Athos/tests/conftest.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import pytest import tempfile import shutil diff --git a/Athos/tests/tf/unittests/test_arith_binops.py b/Athos/tests/tf/unittests/test_arith_binops.py index 7ad5bd41..3f7c2dfe 100644 --- a/Athos/tests/tf/unittests/test_arith_binops.py +++ b/Athos/tests/tf/unittests/test_arith_binops.py @@ -1,3 +1,27 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + import tensorflow as tf import numpy as np diff --git a/Athos/tests/tf/unittests/test_batchnorm.py b/Athos/tests/tf/unittests/test_batchnorm.py index 5b6f393d..9b0e611e 100644 --- a/Athos/tests/tf/unittests/test_batchnorm.py +++ b/Athos/tests/tf/unittests/test_batchnorm.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import numpy as np diff --git a/Athos/tests/tf/unittests/test_convolution.py b/Athos/tests/tf/unittests/test_convolution.py index 1cb24479..0258c677 100644 --- a/Athos/tests/tf/unittests/test_convolution.py +++ b/Athos/tests/tf/unittests/test_convolution.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import numpy as np diff --git a/Athos/tests/tf/unittests/test_non_linear.py b/Athos/tests/tf/unittests/test_non_linear.py index 356fe6e5..5fabae41 100644 --- a/Athos/tests/tf/unittests/test_non_linear.py +++ b/Athos/tests/tf/unittests/test_non_linear.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import numpy as np diff --git a/Athos/tests/tf/unittests/test_shape_manipulation.py b/Athos/tests/tf/unittests/test_shape_manipulation.py index bba9e3aa..1f73113a 100644 --- a/Athos/tests/tf/unittests/test_shape_manipulation.py +++ b/Athos/tests/tf/unittests/test_shape_manipulation.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import numpy as np diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py index 519eb539..8e0b8c15 100644 --- a/Athos/tests/tf/unittests/test_unaryops.py +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tensorflow as tf import numpy as np diff --git a/Athos/tests/utils.py b/Athos/tests/utils.py index c678079b..80e27e3f 100644 --- a/Athos/tests/utils.py +++ b/Athos/tests/utils.py @@ -1,3 +1,26 @@ +''' + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' import tempfile import sys import os From cd3356cd17168bd6b83fb4c0ce381f366c4779f5 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Tue, 12 Jan 2021 14:52:17 +0530 Subject: [PATCH 32/72] Update test cases and add timeout --- Athos/tests/tf/unittests/test_arith_binops.py | 1 - Athos/tests/tf/unittests/test_convolution.py | 6 +++++ Athos/tests/tf/unittests/test_non_linear.py | 4 +-- .../tf/unittests/test_shape_manipulation.py | 9 +++---- Athos/tests/tf/unittests/test_unaryops.py | 20 +++++++++----- Athos/tests/utils.py | 26 +++++++++++++------ 6 files changed, 42 insertions(+), 24 deletions(-) diff --git a/Athos/tests/tf/unittests/test_arith_binops.py b/Athos/tests/tf/unittests/test_arith_binops.py index 3f7c2dfe..d6c415f0 100644 --- a/Athos/tests/tf/unittests/test_arith_binops.py +++ b/Athos/tests/tf/unittests/test_arith_binops.py @@ -164,7 +164,6 @@ def test_matmul( config = Config(backend).add_input(a).add_output(output) if not bisModel: config.add_input(b) - config.config["scale"] = 12 compiler = Compiler(graph, config, test_dir) mpc_output = compiler.compile_and_run([a_inp]) assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) diff --git a/Athos/tests/tf/unittests/test_convolution.py b/Athos/tests/tf/unittests/test_convolution.py index 0258c677..0301deaa 100644 --- a/Athos/tests/tf/unittests/test_convolution.py +++ b/Athos/tests/tf/unittests/test_convolution.py @@ -45,6 +45,8 @@ ) @pytest.mark.parametrize("dtype", [np.single]) def test_conv(test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, dtype): + if tfOp == tf.nn.conv3d and backend in ["2PC_HE", "2PC_OT"]: + pytest.skip("[conv3d] Missing Support in SCI") graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) kernel_inp = dtype(np.random.randn(*kernel_shape)) @@ -56,6 +58,7 @@ def test_conv(test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, expected_output = sess.run(output, feed_dict={a: a_inp}) config = Config(backend).add_input(a).add_output(output) + config.config["scale"] = 12 compiler = Compiler(graph, config, test_dir) mpc_output = compiler.compile_and_run([a_inp]) assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) @@ -96,6 +99,8 @@ def test_conv_transpose( padding, dtype, ): + if backend in ["2PC_HE", "2PC_OT"]: + pytest.skip("[conv3d] Missing Support in SCI") graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) kernel_inp = dtype(np.random.randn(*kernel_shape)) @@ -107,6 +112,7 @@ def test_conv_transpose( expected_output = sess.run(output, feed_dict={a: a_inp}) config = Config(backend).add_input(a).add_output(output) + config.config["scale"] = 12 compiler = Compiler(graph, config, test_dir) mpc_output = compiler.compile_and_run([a_inp]) assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) diff --git a/Athos/tests/tf/unittests/test_non_linear.py b/Athos/tests/tf/unittests/test_non_linear.py index 5fabae41..1c35d050 100644 --- a/Athos/tests/tf/unittests/test_non_linear.py +++ b/Athos/tests/tf/unittests/test_non_linear.py @@ -35,7 +35,7 @@ @pytest.mark.skip(reason="[non-linear] Haven't made non-linear functionalities public") -@pytest.mark.parametrize("a_shape", [(4, 4), (1,), ()]) +@pytest.mark.parametrize("a_shape", [[4, 4], [1], []]) @pytest.mark.parametrize("dtype", [np.single]) @pytest.mark.parametrize( "tfOp", @@ -63,7 +63,7 @@ def test_non_linear(test_dir, backend, tfOp, a_shape, dtype): return @pytest.mark.skip(reason="[softmax] Haven't made non-linear functionalities public") -@pytest.mark.parametrize("a_shape, axis", [((2, 3), 1), ((1,), 0)]) +@pytest.mark.parametrize("a_shape, axis", [([2, 3], 1), ([1], 0)]) @pytest.mark.parametrize("dtype", [np.single]) def test_softmax(test_dir, backend, a_shape, axis, dtype): graph = tf.Graph() diff --git a/Athos/tests/tf/unittests/test_shape_manipulation.py b/Athos/tests/tf/unittests/test_shape_manipulation.py index 1f73113a..6bbf9547 100644 --- a/Athos/tests/tf/unittests/test_shape_manipulation.py +++ b/Athos/tests/tf/unittests/test_shape_manipulation.py @@ -40,11 +40,9 @@ ([2, 3], [6]), ([6], [2, 3]), ([2, 3], [3, 2]), - ([2, 3], [-1]), # Flatten 1-D - pytest.param( - [1], [], marks=pytest.mark.skip(reason="[reshape] dumping weights error") - ), # convert to scalar - ([3, 2, 3], [2, -1]), # infer -1 as 9 + ([2, 3], [-1]), # Flatten 1-D, + ([1], []), # convert to scalar, + ([3, 2, 3], [2, -1]), # infer -1 as 9, ([3, 2, 3], [-1, 9]), # infer -1 as 2 ], ) @@ -126,7 +124,6 @@ def test_split(test_dir, backend, a_shape, num_or_size_splits, axis, dtype): # Squeeze -# TODO: also add a squeeze dim example. @pytest.mark.parametrize( "a_shape, axis", [ diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py index 8e0b8c15..8eadc0c1 100644 --- a/Athos/tests/tf/unittests/test_unaryops.py +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -47,12 +47,12 @@ ), tf.shape, tf.identity, - pytest.param( - tf.zeros_like, marks=pytest.mark.skip(reason="[zeros_like] EzPC issue for inp=[2,2]") - ), + tf.zeros_like, ], ) def test_uop(test_dir, backend, tfOp, a_shape, dtype): + if backend.startswith("2PC") and tfOp == tf.math.square: + pytest.skip("[SCI][square] Secret Secret mul not implemented") graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) with graph.as_default(): @@ -80,7 +80,7 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype): ) @pytest.mark.parametrize("dtype", [np.single]) @pytest.mark.parametrize("tfOp", [tf.math.reduce_mean, tf.reduce_sum]) -@pytest.mark.skip(reason="[reduce] Reduce mean assert shape failure") +#@pytest.mark.skip(reason="[reduce] Reduce mean output mismatch and shape failure") def test_reduce(test_dir, backend, tfOp, a_shape, axis, keepdims, dtype): graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) @@ -103,10 +103,13 @@ def test_reduce(test_dir, backend, tfOp, a_shape, axis, keepdims, dtype): ([3, 2], None), ([3, 2], 0), ([3, 2], 1), + ([3, 2, 3], 1), + ([3, 2, 1, 1], 1), + ([3, 2], 1), ], ) @pytest.mark.parametrize("dtype", [np.single]) -@pytest.mark.skip(reason="[argmax] Generic argmax not implemented") +@pytest.mark.skip(reason="[argmax] Need support for argmax along arbitrary axis") def test_argmax(test_dir, backend, a_shape, axis, dtype): graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) @@ -143,6 +146,8 @@ def test_argmax(test_dir, backend, a_shape, axis, dtype): def test_pool( test_dir, backend, tfOp, a_shape, ksize, strides, padding, data_format, dtype ): + if backend.startswith("2PC") and tfOp == tf.nn.max_pool: + pytest.skip("[SCI][maxpool] Output mismatch bug") graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) with graph.as_default(): @@ -173,9 +178,10 @@ def test_pool( "from_dtype, to_dtype", [ (np.single, np.single), - ( + pytest.param( np.double, np.single, + marks=pytest.mark.skip(reason="[cast] Support for parsing DOUBLES"), ), pytest.param( np.single, @@ -212,6 +218,6 @@ def test_fill(test_dir, backend, a_shape, value): config = Config(backend).add_output(output) compiler = Compiler(graph, config, test_dir) - mpc_output = compiler.compile_and_run([]) + mpc_output = compiler.compile_and_run([], timeoutSeconds=60) assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) return \ No newline at end of file diff --git a/Athos/tests/utils.py b/Athos/tests/utils.py index 80e27e3f..9086685f 100644 --- a/Athos/tests/utils.py +++ b/Athos/tests/utils.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tempfile import sys import os @@ -51,11 +51,13 @@ def __init__(self, mode): elif mode == "2PC_OT": self.config["target"] = "PORTHOS2PC" self.config["bitlength"] = 41 + self.config["scale"] = 12 self.config["backend"] = "OT" elif mode == "2PC_HE": self.config["target"] = "PORTHOS2PC" self.config["bitlength"] = 41 + self.config["scale"] = 12 self.config["backend"] = "HE" else: assert False, "Mode has to be one of CPP/3PC/2PC_OT/2PC_HE" @@ -129,7 +131,7 @@ def __init__(self, program_path, model_weight_path, params, test_dir): self.target = params["target"] self.test_dir = test_dir - def run(self, inputs): + def run(self, inputs, timeoutSeconds): # scale input and dump to file inputs_scaled = os.path.join( self.test_dir, "input_fixedpt_scale_" + str(self.scale) + ".inp" @@ -175,7 +177,10 @@ def run(self, inputs): commands = [client_cmd, server_cmd, party2_cmd] procs = [subprocess.Popen(i, shell=True) for i in commands] for p in procs: - p.wait() + try: + p.wait(timeoutSeconds) + except subprocess.TimeoutExpired: + p.kill() elif self.target == "PORTHOS2PC": util_dir = os.path.dirname(os.path.abspath(__file__)) sci_dir = os.path.join(util_dir, "..", "..", "SCI") @@ -195,7 +200,10 @@ def run(self, inputs): commands = [client_cmd, server_cmd] procs = [subprocess.Popen(i, shell=True) for i in commands] for p in procs: - p.wait() + try: + p.wait(timeoutSeconds) + except subprocess.TimeoutExpired: + p.kill() return convert_raw_output_to_np(raw_output, self.bitlength, self.scale) @@ -205,13 +213,15 @@ def __init__(self, graph, config, test_dir): self.config = config.config self.test_dir = test_dir - def compile_and_run(self, inputs): + def compile_and_run(self, inputs, timeoutSeconds=40): save_graph(self.graph_def, self.config, self.test_dir) params = get_params(self.config) print(params) - (output_program, model_weight_file) = CompileTFGraph.generate_code(params) + (output_program, model_weight_file) = CompileTFGraph.generate_code( + params, debug=False + ) prog = Program(output_program, model_weight_file, params, self.test_dir) - output = prog.run(inputs) + output = prog.run(inputs, timeoutSeconds) return output From 08421b2cf1a7cf91be7c73bdf8f0ebc446de35ad Mon Sep 17 00:00:00 2001 From: Bhatu Date: Tue, 12 Jan 2021 14:53:04 +0530 Subject: [PATCH 33/72] Add debug flag in compile script --- Athos/CompileTFGraph.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index f470557c..1d4e5e00 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -79,7 +79,7 @@ def parse_args(): args = parser.parse_args() return args -def generate_code(params): +def generate_code(params, debug=False): # Mandatory model_name = params["model_name"] input_tensor_info = params["input_tensors"] @@ -201,8 +201,9 @@ def generate_code(params): "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args ) os.system( - "cp {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) + "mv {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) ) + os.system("rm {}".format(ezpc_file_name)) output_file = os.path.join(model_abs_dir, output_name) if target == "PORTHOS2PC": @@ -211,20 +212,24 @@ def generate_code(params): program_name = model_base_name + "_" + target + ".out" program_path = os.path.join(model_abs_dir, program_name) os.chdir(model_abs_dir) + if debug: + opt_flag = "-O0 -g" + else: + opt_flag = "-O3" if target in [ "CPP", "CPPRING"]: os.system( - "g++ -O3 -w {file} -o {output}".format(file=output_file, output=program_path) + "g++ {opt_flag} -w {file} -o {output}".format(file=output_file, output=program_path, opt_flag=opt_flag) ) elif target == "PORTHOS": porthos_src = os.path.join(athos_dir, "..", "Porthos", "src") porthos_lib = os.path.join(porthos_src, "build", "lib") if os.path.exists(porthos_lib): os.system( - """g++ -O3 -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ + """g++ {opt_flag} -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ -mrdseed -fpermissive -fpic -std=c++17 -L {porthos_lib} -I {porthos_headers} {file} \ -lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system \ -o {output}""".format(porthos_lib=porthos_lib, porthos_headers=porthos_src, - file=output_file, output=program_path) + file=output_file, output=program_path, opt_flag=opt_flag) ) else: print("Not compiling generated code. Please follow the readme and build Porthos.") @@ -236,11 +241,11 @@ def generate_code(params): seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib") if os.path.exists(sci_lib): os.system( - """g++ -O3 -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ + """g++ {opt_flag} -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ -faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {file} \ -L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto \ -o {output}""".format(eigen=eigen_path, sci_src=sci_src, - file=output_file,sci_lib=sci_lib,seal=seal_lib_path, output=program_path) + file=output_file,sci_lib=sci_lib,seal=seal_lib_path, output=program_path, opt_flag=opt_flag) ) else: print("Not compiling generated code. Please follow the readme and build SCI.") From fd7061db9e8c0d82056799a13d34b993afd55170 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Tue, 12 Jan 2021 14:53:28 +0530 Subject: [PATCH 34/72] Fix reduce_mean/sum of tensors to 1 dimension. Closes #96 --- Athos/SeeDot/IR/IRBuilderCSF.py | 5 ++++- Athos/TFCompiler/TFNodesAST.py | 12 ++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index a8279ae6..5561cede 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -1289,11 +1289,14 @@ def visitReduce(self, node:AST.Reduce, args=None): if node.keepdims == 1: calculated_shape.append(1) outputiters.append(IR.Int(0,32)) + if calculated_shape == []: + calculated_shape = [1] + outputiters.append(IR.Int(0,32)) # perm will now be [ 1 ,2 ] + [ 0, 3] perm.extend(reduced_dims) loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] outputShape = node.type.shape - assert(calculated_shape == outputShape) + assert calculated_shape == outputShape, "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape) sumExpr = self.getTempVar() sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int()) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index e1345b12..fe934564 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -589,7 +589,11 @@ def Sum(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dic reductionAxesNodeName = inputsRef[1] redAxesN = graph.__getitem__(reductionAxesNodeName) redAxesT = redAxesN.getAttrVal("value").getTensor() - reductionAxesList = redAxesT.getContentAsValArr() + rank = redAxesT.getShapeRef().getRank() + if rank != 0: + reductionAxesList = redAxesT.getContentAsValArr() + else: + reductionAxesList = [redAxesT.getConstantVal()] curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), @@ -609,7 +613,11 @@ def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di reductionAxesNodeName = inputsRef[1] redAxesN = graph.__getitem__(reductionAxesNodeName) redAxesT = redAxesN.getAttrVal("value").getTensor() - reductionAxesList = redAxesT.getContentAsValArr() + rank = redAxesT.getShapeRef().getRank() + if rank != 0: + reductionAxesList = redAxesT.getContentAsValArr() + else: + reductionAxesList = [redAxesT.getConstantVal()] curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), From a489c49d5c92f51df0277a7e5751e1b8baeb0bc1 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Tue, 12 Jan 2021 16:51:23 +0530 Subject: [PATCH 35/72] Fixes output mismatch in reduce_mean and reduce_sum. Closes #97 --- Athos/.gitignore | 1 - Athos/SeeDot/IR/IRBuilderCSF.py | 56 +++++++++++++++++------ Athos/tests/.gitignore | 2 + Athos/tests/tf/unittests/test_unaryops.py | 3 ++ 4 files changed, 46 insertions(+), 16 deletions(-) create mode 100644 Athos/tests/.gitignore diff --git a/Athos/.gitignore b/Athos/.gitignore index c36857d5..617a7ab9 100644 --- a/Athos/.gitignore +++ b/Athos/.gitignore @@ -10,4 +10,3 @@ SeeDot/debug/ *__temp1.ezpc *__temp2.ezpc __pycache__/ -tests/debug diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 5561cede..8aa3370b 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -1278,9 +1278,19 @@ def visitReduce(self, node:AST.Reduce, args=None): outputiters = [] no_elems = 1 j = 0 + for i in range(len(inputShape)): if i not in reduced_dims: perm.append(i) + # perm will now be [ 1 ,2 ] + [ 0, 3] + perm.extend(reduced_dims) + print(perm) + print(reduced_dims) + loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + shuffled_inputiters = [inputiters[perm[i]] for i in range(len(inputShape))] + + for i in range(len(inputShape)): + if i not in reduced_dims: calculated_shape.append(inputShape[i]) outputiters.append(inputiters[j]) j = j + 1 @@ -1289,30 +1299,35 @@ def visitReduce(self, node:AST.Reduce, args=None): if node.keepdims == 1: calculated_shape.append(1) outputiters.append(IR.Int(0,32)) + if calculated_shape == []: calculated_shape = [1] outputiters.append(IR.Int(0,32)) - # perm will now be [ 1 ,2 ] + [ 0, 3] - perm.extend(reduced_dims) - loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + outputShape = node.type.shape assert calculated_shape == outputShape, "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape) sumExpr = self.getTempVar() sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int()) initSumCmd = IR.Assn(sumExpr, IRUtil.zero) - updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, inputiters))) - - outer_nesting = len(inputShape) - len(reduced_dims) - temp_flat = self.getTempVar() - temp_flat_decl = IR.Decl(temp_flat.idf, - Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), - isSecret=node.type.isSecret) - # i1*s2 + i2 - flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) - # temp_flat[i1*s2 + i2] = sum - temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) - updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, shuffled_inputiters))) + + if node.op == AST.Operators.Mean: + outer_nesting = len(inputShape) - len(reduced_dims) + temp_flat = self.getTempVar() + temp_flat_decl = IR.Decl(temp_flat.idf, + Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), + isSecret=node.type.isSecret) + # i1*s2 + i2 + flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) + # temp_flat[i1*s2 + i2] = sum + temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) + updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + elif node.op == AST.Operators.ADD: + output = self.getTempVar() + output_decl = IR.Decl(output.idf, node.type) + out_expr = IRUtil.addIndex(output, outputiters) + updateOutCmd = IR.Assn(out_expr, sumExpr) # Generate the sum loop inner_loops_processed = 0 @@ -1323,6 +1338,17 @@ def visitReduce(self, node:AST.Reduce, args=None): if(inner_loops_processed == len(reduced_dims)): sum_loop = [initSumCmd] + sum_loop + [updateOutCmd] + if node.op == AST.Operators.ADD: + comment = IR.Comment(str(node.metadata)) + final_prog = IRUtil.prog_merge( prog_1, + IR.Prog([comment]), + IR.Prog([sumExpr_decl, output_decl]), + IR.Prog(sum_loop)) + if not(Util.Config.disableTruncOpti): + self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] + + return (final_prog, output) + # Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat) out_flat = self.getTempVar() out_flat_decl = IR.Decl(out_flat.idf, diff --git a/Athos/tests/.gitignore b/Athos/tests/.gitignore new file mode 100644 index 00000000..169aaf56 --- /dev/null +++ b/Athos/tests/.gitignore @@ -0,0 +1,2 @@ +results-Porthos2PC-server.csv +debug diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py index 8eadc0c1..87e9b77c 100644 --- a/Athos/tests/tf/unittests/test_unaryops.py +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -75,6 +75,9 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype): ([3, 2], [0, 1], False), ([3, 2], 0, False), ([3, 2], 1, False), + ([3, 2, 4], 1, False), + ([3, 2, 4], [1, 2], False), + ([3, 2, 4], [2, 1], False), ([3, 2], 0, True), ], ) From 5f699854fb06183db3cadb0b9fb5d5f71f655448 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 20 Jan 2021 10:32:15 +0530 Subject: [PATCH 36/72] Add script to compile and run models in Networks directory. Usage: python CompileSampleNetworks.py Networks/sample_network.config This will compile and run ResNet by creating a tmux session named ResNet. Do python CompileSampleNetworks.py --help to see full usage. --- Athos/CompileSampleNetworks.py | 341 +++++++++++++++ Athos/CompileTFGraph.py | 397 +++++++++--------- Athos/CompilerScripts/parse_config.py | 47 ++- .../sample_networks/print_stats_2pc.sh | 71 ++++ .../sample_networks/print_stats_3pc.sh | 64 +++ .../sample_networks/print_stats_cpp.sh | 46 ++ .../sample_networks/run_demo_2pc.sh | 66 +++ .../sample_networks/run_demo_3pc.sh | 73 ++++ .../sample_networks/run_demo_cpp.sh | 50 +++ Athos/HelperScripts/SetupCIFAR10.sh | 9 +- Athos/Networks/.gitignore | 3 + Athos/Networks/ChestXRay/ChestXRay_tf_main.py | 27 +- Athos/Networks/ChestXRay/setup_and_run.sh | 52 +++ Athos/Networks/DenseNet/.gitignore | 1 + Athos/Networks/DenseNet/DenseNet_main.py | 25 +- Athos/Networks/DenseNet/setup_and_run.sh | 51 +++ Athos/Networks/ResNet/ResNet_main.py | 29 +- Athos/Networks/ResNet/setup_and_run.sh | 48 +++ .../SqueezeNetCIFAR10/Squeezenet_model.py | 22 +- .../SqueezeNetCIFAR10/setup_and_run.sh | 59 +++ .../SqueezeNetImgNet/setup_and_run.sh | 43 ++ .../SqueezeNetImgNet/squeezenet_main.py | 27 +- Athos/Networks/sample_network.config | 5 + Athos/SeeDot/Compiler.py | 2 +- Athos/TFCompiler/ProcessTFGraph.py | 5 +- 25 files changed, 1331 insertions(+), 232 deletions(-) create mode 100644 Athos/CompileSampleNetworks.py create mode 100755 Athos/CompilerScripts/sample_networks/print_stats_2pc.sh create mode 100755 Athos/CompilerScripts/sample_networks/print_stats_3pc.sh create mode 100755 Athos/CompilerScripts/sample_networks/print_stats_cpp.sh create mode 100755 Athos/CompilerScripts/sample_networks/run_demo_2pc.sh create mode 100755 Athos/CompilerScripts/sample_networks/run_demo_3pc.sh create mode 100755 Athos/CompilerScripts/sample_networks/run_demo_cpp.sh mode change 100644 => 100755 Athos/HelperScripts/SetupCIFAR10.sh create mode 100644 Athos/Networks/.gitignore create mode 100755 Athos/Networks/ChestXRay/setup_and_run.sh create mode 100644 Athos/Networks/DenseNet/.gitignore create mode 100755 Athos/Networks/DenseNet/setup_and_run.sh create mode 100755 Athos/Networks/ResNet/setup_and_run.sh create mode 100755 Athos/Networks/SqueezeNetCIFAR10/setup_and_run.sh create mode 100755 Athos/Networks/SqueezeNetImgNet/setup_and_run.sh create mode 100644 Athos/Networks/sample_network.config diff --git a/Athos/CompileSampleNetworks.py b/Athos/CompileSampleNetworks.py new file mode 100644 index 00000000..5c474ef7 --- /dev/null +++ b/Athos/CompileSampleNetworks.py @@ -0,0 +1,341 @@ +""" + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +""" +import argparse +from argparse import RawTextHelpFormatter + +import os +import os.path +import json +import sys + +import TFCompiler.ProcessTFGraph as Athos +import CompilerScripts.parse_config as parse_config + + +def parse_args(): + parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--config", + required=True, + type=str, + help="""Path to the config json file +Config file should be a json in the following format: +{ + //--------------------------- Mandatory options --------------------------- + + "network_name":"ResNet", // Any network name from Athos/Networks directory (ResNet/DenseNet/ChestXRay/..) + "target":"PORTHOS2PC", // Compilation target. ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC + + + + //--------------------------- Optional options --------------------------- + "scale":10, // Scaling factor to compile for. DEFAULT=12. + "bitlength":64, // Bit length to compile for. DEFAULT=64. + + "modulo" : 32, // Modulo to be used for shares. Applicable for + // CPPRING/PORTHOS2PC backend. For + // PORTHOS2PC + backend=OT => Power of 2 + // PORTHOS2PC + backend=HE => Prime value." + + "backend" : "OT", // Backend to be used - OT/HE (DEFAULT=OT). + // Only applicable for PORTHOS2PC backend + + "disable_all_hlil_opts" : false, // Disable all optimizations in HLIL. DEFAULT=false + "disable_relu_maxpool_opts" : false, // Disable Relu-Maxpool optimization. DEFAULT=false + "disable_garbage_collection" : false, // Disable Garbage Collection optimization. DEFAULT=false + "disable_trunc_opts" : false // Disable truncation placement optimization. DEFAULT=false + "run_in_tmux" : false // Also run the compiled program in a new tmux session +} +""", + ) + args = parser.parse_args() + return args + + +def generate_code(params, debug=False): + network_name = params["network_name"] + assert network_name in [ + "ResNet", + "DenseNet", + "SqueezeNetImgNet", + "SqueezeNetCIFAR10" + ], "Network must be any of ResNet/DenseNet/SqueezeNetImgNet/SqueezeNetCIFAR10" + scale = 12 if params["scale"] is None else params["scale"] + bitlength = 64 if params["bitlength"] is None else params["bitlength"] + target = params["target"] + disable_all_hlil_opts = ( + False + if params["disable_all_hlil_opts"] is None + else params["disable_all_hlil_opts"] + ) + disable_relu_maxpool_opts = ( + False + if params["disable_relu_maxpool_opts"] is None + else params["disable_relu_maxpool_opts"] + ) + disable_garbage_collection = ( + False + if params["disable_garbage_collection"] is None + else params["disable_garbage_collection"] + ) + disable_trunc_opts = ( + False if params["disable_trunc_opts"] is None else params["disable_trunc_opts"] + ) + modulo = params["modulo"] + backend = "OT" if params["backend"] is None else params["backend"] + run_in_tmux = False if params["run_in_tmux"] is None else params["run_in_tmux"] + + assert bitlength <= 64 and bitlength >= 1, "Bitlen must be >= 1 and <= 64" + assert target in [ + "PORTHOS", + "PORTHOS2PC", + "ABY", + "CPP", + "CPPRING", + ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" + + cwd = os.getcwd() + athos_dir = os.path.dirname(os.path.abspath(__file__)) + model_abs_dir = os.path.join(athos_dir, "Networks", network_name) + if not os.path.exists(model_abs_dir): + sys.exit("Model directory {} does not exist".format(model_abs_dir)) + + # Generate graphdef and sizeInfo metadata, and dump model weights + os.chdir(model_abs_dir) + os.system("./setup_and_run.sh {scale}".format(scale=scale)) + os.chdir(cwd) + print( + "--------------------------------------------------------------------------------" + ) + print("Compiling model to EzPC") + print( + "--------------------------------------------------------------------------------" + ) + + # Compile to seedot. Generate AST in model directory + print("model_dir = ", model_abs_dir) + Athos.process_tf_graph(model_abs_dir) + + # Compile to ezpc + model_base_name = network_name + ezpc_file_name = "{mname}_{bl}_{target}.ezpc".format( + mname=model_base_name, bl=bitlength, target=target.lower() + ) + ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) + + seedot_args = "" + seedot_args += "--astFile {}/astOutput.pkl --consSF {} ".format( + model_abs_dir, scale + ) + seedot_args += "--bitlen {} --outputFileName {} ".format(bitlength, ezpc_abs_path) + seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) + seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) + seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) + seedot_args += "--disableTruncOpti {} ".format(disable_trunc_opts) + + seedot_script = os.path.join(athos_dir, "SeeDot", "SeeDot.py") + print("python3 {} ".format(seedot_script) + seedot_args) + os.system("python3 {} ".format(seedot_script) + seedot_args) + + # Add library functions + if target in ["ABY", "CPPRING"]: + library = "cpp" + else: + library = target.lower() + + lib_bitlength = 64 if bitlength > 32 else 32 + library_dir = os.path.join(athos_dir, "TFEzPCLibrary") + common = os.path.join(library_dir, "Library{}_common.ezpc".format(lib_bitlength)) + if library == "cpp": + pre = os.path.join( + library_dir, "Library{}_{}_pre.ezpc".format(lib_bitlength, library) + ) + post = os.path.join( + library_dir, "Library{}_{}_post.ezpc".format(lib_bitlength, library) + ) + else: + pre = os.path.join( + library_dir, "Library{}_{}.ezpc".format(lib_bitlength, library) + ) + post = "" + temp = os.path.join(model_abs_dir, "temp.ezpc") + os.system( + "cat {pre} {common} {post} {ezpc}> {temp}".format( + pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp + ) + ) + os.system("mv {temp} {ezpc}".format(temp=temp, ezpc=ezpc_abs_path)) + + ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") + # Copy generated code to the ezpc directory + os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) + os.chdir(ezpc_dir) + ezpc_args = "" + ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( + bl=bitlength, target=target + ) + output_name = ezpc_file_name[:-5] + "0.cpp" + if modulo is not None: + ezpc_args += "--modulo {} ".format(modulo) + if target == "PORTHOS2PC": + ezpc_args += "--backend {} ".format(backend.upper()) + output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) + if target in ["PORTHOS"]: + ezpc_args += "--sf {} ".format(scale) + + os.system( + "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args + ) + os.system( + "mv {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) + ) + os.system("rm {}".format(ezpc_file_name)) + output_file = os.path.join(model_abs_dir, output_name) + + print( + "--------------------------------------------------------------------------------" + ) + print("Compiling generated {} code".format(target)) + print( + "--------------------------------------------------------------------------------" + ) + if target == "PORTHOS2PC": + program_name = model_base_name + "_" + target + "_" + backend + ".out" + else: + program_name = model_base_name + "_" + target + ".out" + + program_path = os.path.join(model_abs_dir, program_name) + os.chdir(model_abs_dir) + + if debug: + opt_flag = "-O0 -g" + else: + opt_flag = "-O3" + + if target in ["CPP", "CPPRING"]: + os.system( + "g++ {opt_flag} -w {file} -o {output}".format( + file=output_file, output=program_path, opt_flag=opt_flag + ) + ) + elif target == "PORTHOS": + porthos_src = os.path.join(athos_dir, "..", "Porthos", "src") + porthos_lib = os.path.join(porthos_src, "build", "lib") + if os.path.exists(porthos_lib): + os.system( + """g++ {opt_flag} -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ + -mrdseed -fpermissive -fpic -std=c++17 -L {porthos_lib} -I {porthos_headers} {file} \ + -lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system \ + -o {output}""".format( + porthos_lib=porthos_lib, + porthos_headers=porthos_src, + file=output_file, + output=program_path, + opt_flag=opt_flag, + ) + ) + else: + print( + "Not compiling generated code. Please follow the readme and build Porthos." + ) + elif target == "PORTHOS2PC": + sci = os.path.join(athos_dir, "..", "SCI") + sci_src = os.path.join(sci, "src") + sci_lib = os.path.join(sci, "build", "lib") + eigen_path = os.path.join(sci, "extern", "eigen") + seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib") + if os.path.exists(sci_lib): + os.system( + """g++ {opt_flag} -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ + -faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {file} \ + -L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto \ + -o {output}""".format( + eigen=eigen_path, + sci_src=sci_src, + file=output_file, + sci_lib=sci_lib, + seal=seal_lib_path, + output=program_path, + opt_flag=opt_flag, + ) + ) + else: + print( + "Not compiling generated code. Please follow the readme and build SCI before running this script." + ) + + os.chdir(cwd) + + input_path = os.path.join(model_abs_dir, "model_input_scale_{}.inp".format(scale)) + weights_path = os.path.join( + model_abs_dir, "model_weights_scale_{}.inp".format(scale) + ) + # program_path + if run_in_tmux: + is_tmux_installed = os.system("type tmux > /dev/null") + if is_tmux_installed != 0: + print( + "Not running the program. Tmux is not installed. Please install tmux and run or do the following manually to run." + ) + return + + print( + "--------------------------------------------------------------------------------" + ) + mode = target + " - " + backend if target == "PORTHOS2PC" else target + print("Running program securely in {} mode".format(mode)) + print( + "--------------------------------------------------------------------------------" + ) + + sample_networks_dir = os.path.join( + athos_dir, "CompilerScripts", "sample_networks" + ) + if target in ["CPP", "CPPRING"]: + run_script_path = os.path.join(sample_networks_dir, "run_demo_cpp.sh") + elif target == "PORTHOS": + run_script_path = os.path.join(sample_networks_dir, "run_demo_3pc.sh") + elif target == "PORTHOS2PC": + run_script_path = os.path.join(sample_networks_dir, "run_demo_2pc.sh") + os.system( + "{script} {model_dir} {model_binary} {model_input} {model_weight}".format( + script=run_script_path, + model_dir=model_abs_dir, + model_binary=program_path, + model_input=input_path, + model_weight=weights_path, + ) + ) + print( + "\nAttach to tmux session named {model} to see results (tmux a -t {model})".format( + model=network_name + ) + ) + return + + +if __name__ == "__main__": + args = parse_args() + params = parse_config.get_params(args.config, sample_network=True) + generate_code(params) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 1d4e5e00..e9379f05 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import argparse from argparse import RawTextHelpFormatter @@ -35,225 +35,246 @@ def parse_args(): - parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter) - parser.add_argument( - "--config", - required=True, - type=str, - help="""Path to the config json file + parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--config", + required=True, + type=str, + help="""Path to the config json file Config file should be a json in the following format: { - // Mandatory options + //--------------------------- Mandatory options --------------------------- "model_name":"model.pb", // Tensorflow protobuf file to compile. "output_tensors":[ - "output1", - "output2" + "output1", + "output2" ], "target":"PORTHOS2PC", // Compilation target. ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC + + - // Optional options - "scale":10, // Scaling factor to compile for. Defaults to 12. - "bitlength":64, // Bit length to compile for. Defaults to 64. - "save_weights" : true, // Save model scaled weights in fixed point. Defaults to false. + //--------------------------- Optional options --------------------------- + "scale":10, // Scaling factor to compile for. DEFAULT=12. + "bitlength":64, // Bit length to compile for. DEFAULT=12. + "save_weights" : true, // Save model scaled weights in fixed point. DEFAULT=false. - "input_tensors":{ // Name and shape of the input tensors - "actual_input_1":"224,244,3", // for the model. Not required if the - "input2":"2,245,234,3" // placeholder nodes have shape info. + "input_tensors":{ // Name and shape of the input tensors + "actual_input_1":"224,244,3", // for the model. Not required if the + "input2":"2,245,234,3" // placeholder nodes have shape info in the .pb file. }, "modulo" : 32, // Modulo to be used for shares. Applicable for - // CPPRING/PORTHOS2PC backend. For - // PORTHOS2PC + backend=OT => Power of 2 - // PORTHOS2PC + backend=HE => Prime value." + // CPPRING/PORTHOS2PC backend. For + // PORTHOS2PC + backend=OT => Power of 2 + // PORTHOS2PC + backend=HE => Prime value." - "backend" : "OT", // Backend to be used - OT/HE (default OT). - // Only applicable for PORTHOS2PC backend + "backend" : "OT", // Backend to be used - OT/HE (default OT). + // Only applicable for PORTHOS2PC backend - "disable_all_hlil_opts" : false, // Disable all optimizations in HLIL - "disable_relu_maxpool_opts" : false, // Disable Relu-Maxpool optimization - "disable_garbage_collection" : false, // Disable Garbage Collection optimization - "disable_trunc_opts" : false // Disable truncation placement optimization + "disable_all_hlil_opts" : false, // Disable all optimizations in HLIL. DEFAULT=false + "disable_relu_maxpool_opts" : false, // Disable Relu-Maxpool optimization. DEFAULT=false + "disable_garbage_collection" : false, // Disable Garbage Collection optimization. DEFAULT=false + "disable_trunc_opts" : false // Disable truncation placement optimization. DEFAULT=false } """, - ) - args = parser.parse_args() - return args + ) + args = parser.parse_args() + return args + def generate_code(params, debug=False): - # Mandatory - model_name = params["model_name"] - input_tensor_info = params["input_tensors"] - output_tensors = params["output_tensors"] - scale = 12 if params["scale"] is None else params["scale"] - bitlength = 64 if params["bitlength"] is None else params["bitlength"] - target = params["target"] - save_weights = params["save_weights"] - save_weights = False if save_weights is None else save_weights + model_name = params["model_name"] + input_tensor_info = params["input_tensors"] + output_tensors = params["output_tensors"] + scale = 12 if params["scale"] is None else params["scale"] + bitlength = 64 if params["bitlength"] is None else params["bitlength"] + target = params["target"] + save_weights = False if params["save_weights"] is None else params["save_weights"] + disable_all_hlil_opts = ( + False + if params["disable_all_hlil_opts"] is None + else params["disable_all_hlil_opts"] + ) + disable_relu_maxpool_opts = ( + False + if params["disable_relu_maxpool_opts"] is None + else params["disable_relu_maxpool_opts"] + ) + disable_garbage_collection = ( + False + if params["disable_garbage_collection"] is None + else params["disable_garbage_collection"] + ) + disable_trunc_opts = ( + False if params["disable_trunc_opts"] is None else params["disable_trunc_opts"] + ) + modulo = params["modulo"] + backend = "OT" if params["backend"] is None else params["backend"] + assert bitlength <= 64 and bitlength >= 1, "Bitlen must be >= 1 and <= 64" + assert target in [ + "PORTHOS", + "PORTHOS2PC", + "ABY", + "CPP", + "CPPRING", + ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" - assert bitlength <= 64 and bitlength >= 1, "Bitlen must be >= 1 and <= 64" - assert target in [ - "PORTHOS", - "PORTHOS2PC", - "ABY", - "CPP", - "CPPRING", - ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" + cwd = os.getcwd() + athos_dir = os.path.dirname(os.path.abspath(__file__)) + model_abs_path = os.path.abspath(model_name) + model_abs_dir = os.path.dirname(model_abs_path) + # Generate graphdef and sizeInfo metadata + weights_path = compile_tf.compile( + model_name, input_tensor_info, output_tensors, scale, save_weights + ) - cwd = os.getcwd() - athos_dir = os.path.dirname(os.path.abspath(__file__)) - model_abs_path = os.path.abspath(model_name) - model_abs_dir = os.path.dirname(model_abs_path) - # Generate graphdef and sizeInfo metadata - weights_path = compile_tf.compile( - model_name, input_tensor_info, output_tensors, scale, save_weights - ) + # Compile to seedot. Generate AST in model directory + Athos.process_tf_graph(model_abs_path) - # Compile to seedot. Generate AST in model directory - Athos.process_tf_graph(model_abs_path) + # Compile to ezpc + model_base_name = os.path.basename(model_abs_path)[:-3] + ezpc_file_name = "{mname}_{bl}_{target}.ezpc".format( + mname=model_base_name, bl=bitlength, target=target.lower() + ) + ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) - # Compile to ezpc - model_base_name = os.path.basename(model_abs_path)[:-3] - ezpc_file_name = "{mname}_{bl}_{target}.ezpc".format( - mname=model_base_name, bl=bitlength, target=target.lower() - ) - ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) - disable_all_hlil_opts = ( - False - if params["disable_all_hlil_opts"] is None - else params["disable_all_hlil_opts"] - ) - disable_relu_maxpool_opts = ( - False - if params["disable_relu_maxpool_opts"] is None - else params["disable_relu_maxpool_opts"] - ) - disable_garbage_collection = ( - False - if params["disable_garbage_collection"] is None - else params["disable_garbage_collection"] - ) - disable_trunc_opts = ( - False if params["disable_trunc_opts"] is None else params["disable_trunc_opts"] - ) - seedot_args = "" - seedot_args += "--astFile {}/astOutput.pkl --consSF {} ".format( - model_abs_dir, scale - ) - seedot_args += "--bitlen {} --outputFileName {} ".format(bitlength, ezpc_abs_path) - seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) - seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) - seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) - seedot_args += "--disableTruncOpti {} ".format(disable_trunc_opts) + seedot_args = "" + seedot_args += "--astFile {}/astOutput.pkl --consSF {} ".format( + model_abs_dir, scale + ) + seedot_args += "--bitlen {} --outputFileName {} ".format(bitlength, ezpc_abs_path) + seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) + seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) + seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) + seedot_args += "--disableTruncOpti {} ".format(disable_trunc_opts) - seedot_script = os.path.join(athos_dir, "SeeDot", "SeeDot.py") - print("python3 {} ".format(seedot_script) + seedot_args) - os.system("python3 {} ".format(seedot_script) + seedot_args) + seedot_script = os.path.join(athos_dir, "SeeDot", "SeeDot.py") + print("python3 {} ".format(seedot_script) + seedot_args) + os.system("python3 {} ".format(seedot_script) + seedot_args) - # Add library functions - if target in ["ABY", "CPPRING"]: - library = "cpp" - else: - library = target.lower() + # Add library functions + if target in ["ABY", "CPPRING"]: + library = "cpp" + else: + library = target.lower() - lib_bitlength = 64 if bitlength > 32 else 32 - library_dir = os.path.join(athos_dir, "TFEzPCLibrary") - common = os.path.join(library_dir, "Library{}_common.ezpc".format(lib_bitlength)) - if library == "cpp": - pre = os.path.join( - library_dir, "Library{}_{}_pre.ezpc".format(lib_bitlength, library) - ) - post = os.path.join( - library_dir, "Library{}_{}_post.ezpc".format(lib_bitlength, library) - ) - else: - pre = os.path.join( - library_dir, "Library{}_{}.ezpc".format(lib_bitlength, library) - ) - post = "" - temp = os.path.join(model_abs_dir, "temp.ezpc") - os.system( - "cat {pre} {common} {post} {ezpc}> {temp}".format( - pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp + lib_bitlength = 64 if bitlength > 32 else 32 + library_dir = os.path.join(athos_dir, "TFEzPCLibrary") + common = os.path.join(library_dir, "Library{}_common.ezpc".format(lib_bitlength)) + if library == "cpp": + pre = os.path.join( + library_dir, "Library{}_{}_pre.ezpc".format(lib_bitlength, library) + ) + post = os.path.join( + library_dir, "Library{}_{}_post.ezpc".format(lib_bitlength, library) + ) + else: + pre = os.path.join( + library_dir, "Library{}_{}.ezpc".format(lib_bitlength, library) + ) + post = "" + temp = os.path.join(model_abs_dir, "temp.ezpc") + os.system( + "cat {pre} {common} {post} {ezpc}> {temp}".format( + pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp + ) ) - ) - os.system("mv {temp} {ezpc}".format(temp=temp, ezpc=ezpc_abs_path)) - - modulo = params["modulo"] - backend = "OT" if params["backend"] is None else params["backend"] - ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") - # Copy generated code to the ezpc directory - os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) - os.chdir(ezpc_dir) - ezpc_args = "" - ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( - bl=bitlength, target=target - ) - output_name = ezpc_file_name[:-5] + "0.cpp" - if modulo is not None: - ezpc_args += "--modulo {} ".format(modulo) - if target == "PORTHOS2PC": - ezpc_args += "--backend {} ".format(backend.upper()) - output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) - if target in ["PORTHOS"]: - ezpc_args += "--sf {} ".format(scale) + os.system("mv {temp} {ezpc}".format(temp=temp, ezpc=ezpc_abs_path)) - os.system( - "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args - ) - os.system( - "mv {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) - ) - os.system("rm {}".format(ezpc_file_name)) - output_file = os.path.join(model_abs_dir, output_name) + ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") + # Copy generated code to the ezpc directory + os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) + os.chdir(ezpc_dir) + ezpc_args = "" + ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( + bl=bitlength, target=target + ) + output_name = ezpc_file_name[:-5] + "0.cpp" + if modulo is not None: + ezpc_args += "--modulo {} ".format(modulo) + if target == "PORTHOS2PC": + ezpc_args += "--backend {} ".format(backend.upper()) + output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) + if target in ["PORTHOS"]: + ezpc_args += "--sf {} ".format(scale) - if target == "PORTHOS2PC": - program_name = model_base_name + "_" + target + "_" + backend + ".out" - else: - program_name = model_base_name + "_" + target + ".out" - program_path = os.path.join(model_abs_dir, program_name) - os.chdir(model_abs_dir) - if debug: - opt_flag = "-O0 -g" - else: - opt_flag = "-O3" - if target in [ "CPP", "CPPRING"]: os.system( - "g++ {opt_flag} -w {file} -o {output}".format(file=output_file, output=program_path, opt_flag=opt_flag) + "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args + ) + os.system( + "mv {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) ) - elif target == "PORTHOS": - porthos_src = os.path.join(athos_dir, "..", "Porthos", "src") - porthos_lib = os.path.join(porthos_src, "build", "lib") - if os.path.exists(porthos_lib): - os.system( - """g++ {opt_flag} -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ + os.system("rm {}".format(ezpc_file_name)) + output_file = os.path.join(model_abs_dir, output_name) + + print("Compiling generated code to {target} target".format(target)) + if target == "PORTHOS2PC": + program_name = model_base_name + "_" + target + "_" + backend + ".out" + else: + program_name = model_base_name + "_" + target + ".out" + program_path = os.path.join(model_abs_dir, program_name) + os.chdir(model_abs_dir) + if debug: + opt_flag = "-O0 -g" + else: + opt_flag = "-O3" + if target in ["CPP", "CPPRING"]: + os.system( + "g++ {opt_flag} -w {file} -o {output}".format( + file=output_file, output=program_path, opt_flag=opt_flag + ) + ) + elif target == "PORTHOS": + porthos_src = os.path.join(athos_dir, "..", "Porthos", "src") + porthos_lib = os.path.join(porthos_src, "build", "lib") + if os.path.exists(porthos_lib): + os.system( + """g++ {opt_flag} -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ -mrdseed -fpermissive -fpic -std=c++17 -L {porthos_lib} -I {porthos_headers} {file} \ -lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system \ - -o {output}""".format(porthos_lib=porthos_lib, porthos_headers=porthos_src, - file=output_file, output=program_path, opt_flag=opt_flag) - ) - else: - print("Not compiling generated code. Please follow the readme and build Porthos.") - elif target == "PORTHOS2PC": - sci = os.path.join(athos_dir, "..", "SCI") - sci_src = os.path.join(sci, "src") - sci_lib = os.path.join(sci, "build", "lib") - eigen_path = os.path.join(sci, "extern", "eigen") - seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib") - if os.path.exists(sci_lib): - os.system( - """g++ {opt_flag} -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ + -o {output}""".format( + porthos_lib=porthos_lib, + porthos_headers=porthos_src, + file=output_file, + output=program_path, + opt_flag=opt_flag, + ) + ) + else: + print( + "Not compiling generated code. Please follow the readme and build Porthos." + ) + elif target == "PORTHOS2PC": + sci = os.path.join(athos_dir, "..", "SCI") + sci_src = os.path.join(sci, "src") + sci_lib = os.path.join(sci, "build", "lib") + eigen_path = os.path.join(sci, "extern", "eigen") + seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib") + if os.path.exists(sci_lib): + os.system( + """g++ {opt_flag} -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ -faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {file} \ -L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto \ - -o {output}""".format(eigen=eigen_path, sci_src=sci_src, - file=output_file,sci_lib=sci_lib,seal=seal_lib_path, output=program_path, opt_flag=opt_flag) - ) - else: - print("Not compiling generated code. Please follow the readme and build SCI.") + -o {output}""".format( + eigen=eigen_path, + sci_src=sci_src, + file=output_file, + sci_lib=sci_lib, + seal=seal_lib_path, + output=program_path, + opt_flag=opt_flag, + ) + ) + else: + print( + "Not compiling generated code. Please follow the readme and build SCI." + ) + + os.chdir(cwd) + return (program_path, weights_path) - os.chdir(cwd) - return (program_path, weights_path) if __name__ == "__main__": - args = parse_args() - params = parse_config.get_params(args.config) - generate_code(params) \ No newline at end of file + args = parse_args() + params = parse_config.get_params(args.config) + generate_code(params) diff --git a/Athos/CompilerScripts/parse_config.py b/Athos/CompilerScripts/parse_config.py index a6189ed6..22624ce4 100644 --- a/Athos/CompilerScripts/parse_config.py +++ b/Athos/CompilerScripts/parse_config.py @@ -127,6 +127,15 @@ def get_str_list_param(config, p_name): assert type(i) == str, p_name + "is not a list of strings" return p +def get_opt_str_list_param(config, p_name): + p = config.get(p_name) + if p is None: + return p + assert type(p) == list, p_name + "is not a list of strings" + for i in p: + assert type(i) == str, p_name + "is not a list of strings" + return p + def get_opt_param(config, p_name): p = config.get(p_name) @@ -154,18 +163,24 @@ def parse_input_tensors(config): return input_t_info -def parse_config(config): - model_fname = get_str_param(config, "model_name") - if not model_fname.endswith(".pb"): - sys.exit( - model_fname - + " is not a tensorflow protobuf file. Please supply " - + "a valid tensorflow protobuf model (.pb extension)" - ) - if not os.path.isfile(model_fname): - sys.exit(model_fname + " file does not exist") +def parse_config(config, sample_network=False): + if not sample_network: + model_fname = get_str_param(config, "model_name") + if not model_fname.endswith(".pb"): + sys.exit( + model_fname + + " is not a tensorflow protobuf file. Please supply " + + "a valid tensorflow protobuf model (.pb extension)" + ) + if not os.path.isfile(model_fname): + sys.exit(model_fname + " file does not exist") + else: + network_name = get_str_param(config, "network_name") + run_in_tmux = get_opt_bool_param(config, "run_in_tmux") + target = get_str_param(config, "target").upper() - output_tensors = get_str_list_param(config, "output_tensors") + + output_tensors = get_opt_str_list_param(config, "output_tensors") input_t_info = parse_input_tensors(config) save_weights = get_opt_bool_param(config, "save_weights") @@ -181,7 +196,6 @@ def parse_config(config): disable_trunc = get_opt_bool_param(config, "disable_trunc_opts") params = { - "model_name": model_fname, "input_tensors": input_t_info, "output_tensors": output_tensors, "scale": scale, @@ -195,10 +209,15 @@ def parse_config(config): "disable_garbage_collection": disable_garbage_collection, "disable_trunc_opts": disable_trunc, } + if sample_network: + params["network_name"] = network_name + params["run_in_tmux"] = run_in_tmux + else: + params["model_name"] = model_fname return params -def get_params(config_fname): +def get_params(config_fname, sample_network=False): config = get_config(config_fname) - params = parse_config(config) + params = parse_config(config, sample_network) return params diff --git a/Athos/CompilerScripts/sample_networks/print_stats_2pc.sh b/Athos/CompilerScripts/sample_networks/print_stats_2pc.sh new file mode 100755 index 00000000..5236c617 --- /dev/null +++ b/Athos/CompilerScripts/sample_networks/print_stats_2pc.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +PARTY=$1 +MODEL_DIR=$2 +MODEL_NAME=$(basename $MODEL_DIR) + +if [ $PARTY -eq 0 ]; +then + PARTYNAME="Server" +elif [ $PARTY -eq 1 ]; +then + PARTYNAME="Client" +fi + +echo "-------------------------------------------------------" +echo " ${MODEL_NAME} results [${PARTYNAME}]" +echo "-------------------------------------------------------" +if [ $PARTY -eq 1 ]; +then + echo "Model outputs:" + echo -e "MPC PORTHOS (2PC) output:\t $(awk '$0==($0+0)' ${MODEL_DIR}/party${PARTY}_mpc_output.out)" + echo -e "Tensorflow output:\t\t $(cat ${MODEL_DIR}/tf_pred.float)" + echo "" +fi +read -r peakmem_kb usertime_s systemtime_s walltime_s <<<$(cat ${MODEL_DIR}/party${PARTY}_stats) +log_wall_t_ms=$(grep "Total time taken =" ${MODEL_DIR}/party${PARTY}_mpc_output.out | grep -o '[0-9]\+[.]*[0-9]\+') +log_wall_t_s=$(echo "scale=2; ${log_wall_t_ms}/1000" | bc) +log_wait_t_s=$(grep "Total wait time =" ${MODEL_DIR}/party${PARTY}_mpc_output.out | grep -o '[0-9]\+[.]*[0-9]\+') +log_wait_t_s=$(echo "scale=2; $log_wait_t_s" | bc) +log_work_t_s=$(echo "scale=2; $log_wall_t_s - $log_wait_t_s" | bc) + +work_percent=$(echo "scale=2; 100 * $log_work_t_s/$log_wall_t_s" | bc) +wait_percent=$(echo "scale=2; 100 * $log_wait_t_s/$log_wall_t_s" | bc) + +wall_m=$(echo "$log_wall_t_s/60" | bc) +wall_s=$(echo "$log_wall_t_s%60/1" | bc) +walltime_m_s="${wall_m}m${wall_s}s" + +peakmem_gb=$(echo "scale=2; $peakmem_kb/1024/1024" | bc) +comm=$(grep "Total data sent" ${MODEL_DIR}/party${PARTY}_mpc_output.out) + +echo "Execution summary for ${PARTYNAME}:" +echo -e "[Communication]:\t\t $comm" +echo -e "Peak Memory Usage:\t\t ${peakmem_kb}KB (${peakmem_gb}GB)" +echo -e "Total time taken:\t\t ${walltime_m_s} ($log_wall_t_s seconds)" +echo -e "Total work time:\t\t ${log_work_t_s} seconds (${work_percent}%)" +echo -e "Time spent waiting:\t\t ${log_wait_t_s} seconds (${wait_percent}%)" +if [ $PARTY -eq 1 ]; +then + echo -e "Time taken by tensorflow:\t $(cat ${MODEL_DIR}/tf_pred.time) seconds" +fi diff --git a/Athos/CompilerScripts/sample_networks/print_stats_3pc.sh b/Athos/CompilerScripts/sample_networks/print_stats_3pc.sh new file mode 100755 index 00000000..5d047f06 --- /dev/null +++ b/Athos/CompilerScripts/sample_networks/print_stats_3pc.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +PARTY=$1 +MODEL_DIR=$2 +MODEL_NAME=$(basename $MODEL_DIR) + +if [ $PARTY -eq 0 ]; +then + PARTYNAME="Client" +elif [ $PARTY -eq 1 ]; +then + PARTYNAME="Server" +else + PARTYNAME="Helper" +fi + +echo "-------------------------------------------------------" +echo " ${MODEL_NAME} results [${PARTYNAME}]" +echo "-------------------------------------------------------" +if [ $PARTY -eq 0 ]; +then + echo "Model outputs:" + echo -e "MPC PORTHOS (3PC) output:\t $(awk '$0==($0+0)' ${MODEL_DIR}/party${PARTY}_mpc_output.out)" + echo -e "Tensorflow output:\t\t $(cat ${MODEL_DIR}/tf_pred.float)" + echo "" +fi +read -r peakmem_kb usertime_s systemtime_s walltime_s <<<$(cat ${MODEL_DIR}/party${PARTY}_stats) +user_percent=$(echo "scale=2; 100*$usertime_s / ($usertime_s + $systemtime_s)" | bc) +system_percent=$(echo "scale=2; 100*$systemtime_s / ($usertime_s + $systemtime_s)" | bc) +wall_user_s=$(echo "scale=2; $user_percent * $walltime_s / 100" | bc) +wall_system_s=$(echo "scale=2; $system_percent * $walltime_s / 100" | bc) +peakmem_gb=$(echo "scale=2; $peakmem_kb/1024/1024" | bc) +comm=$(grep "Communication for execution" ${MODEL_DIR}/party${PARTY}_mpc_output.out) + +echo "Execution summary for ${PARTYNAME}:" +echo -e "$comm" +echo -e "Peak Memory Usage:\t\t ${peakmem_kb} KB (${peakmem_gb}GB)" +echo -e "Total time taken:\t\t ${walltime_s} seconds" +echo -e "Total work time:\t\t ${wall_user_s} seconds (${user_percent}%)" +echo -e "Time spent waiting:\t\t ${wall_system_s} seconds (${system_percent}%)" +if [ $PARTY -eq 0 ]; +then + echo -e "Time taken by tensorflow:\t $(cat ${MODEL_DIR}/tf_pred.time) seconds" +fi \ No newline at end of file diff --git a/Athos/CompilerScripts/sample_networks/print_stats_cpp.sh b/Athos/CompilerScripts/sample_networks/print_stats_cpp.sh new file mode 100755 index 00000000..5866fe0b --- /dev/null +++ b/Athos/CompilerScripts/sample_networks/print_stats_cpp.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +MODEL_DIR=$1 +MODEL_NAME=$(basename $MODEL_DIR) + +echo "-------------------------------------------------------" +echo " ${MODEL_NAME} results" +echo "-------------------------------------------------------" +echo "Model outputs:" +echo -e "MPC Cleartext (CPP) output:\t $(awk '$0==($0+0)' ${MODEL_DIR}/mpc_output.out)" +echo -e "Tensorflow output:\t\t $(cat ${MODEL_DIR}/tf_pred.float)" +echo "" + +read -r peakmem_kb usertime_s systemtime_s walltime_s <<<$(cat ${MODEL_DIR}/party0_stats) +user_percent=$(echo "scale=2; 100*$usertime_s / ($usertime_s + $systemtime_s)" | bc) +system_percent=$(echo "scale=2; 100*$systemtime_s / ($usertime_s + $systemtime_s)" | bc) +wall_user_s=$(echo "scale=2; $user_percent * $walltime_s / 100" | bc) +wall_system_s=$(echo "scale=2; $system_percent * $walltime_s / 100" | bc) +peakmem_gb=$(echo "scale=2; $peakmem_kb/1024/1024" | bc) + +echo "Execution summary:" +echo -e "Peak Memory Usage:\t\t ${peakmem_kb} KB (${peakmem_gb}GB)" +echo -e "Total time taken:\t\t ${walltime_s} seconds" +echo -e "Total work time:\t\t ${wall_user_s} seconds (${user_percent}%)" +echo -e "Time spent waiting:\t\t ${wall_system_s} seconds (${system_percent}%)" +echo -e "Time taken by tensorflow:\t $(cat ${MODEL_DIR}/tf_pred.time) seconds" \ No newline at end of file diff --git a/Athos/CompilerScripts/sample_networks/run_demo_2pc.sh b/Athos/CompilerScripts/sample_networks/run_demo_2pc.sh new file mode 100755 index 00000000..0b0db262 --- /dev/null +++ b/Athos/CompilerScripts/sample_networks/run_demo_2pc.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +MODEL_DIR=$1 +MODEL_BINARY_PATH=$2 +MODEL_INPUT_PATH=$3 +MODEL_WEIGHT_PATH=$4 + +MODEL_NAME=$(basename ${MODEL_DIR}) +SESSION_NAME=${MODEL_NAME} + +tmux has-session -t "${SESSION_NAME}" > /dev/null 2>&1 +if [ "$?" -eq 0 ]; then + echo "Killing existing tmux ${SESSION_NAME} session" + tmux kill-session -t "${SESSION_NAME}" +fi + +tmux new-session -s "${SESSION_NAME}" -d +tmux split-window -v -t "${SESSION_NAME}:0.0" + +tmux send-keys -t "${SESSION_NAME}:0.0" "clear" Enter +tmux send-keys -t "${SESSION_NAME}:0.1" "clear" Enter + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +TIME_CMD="/usr/bin/time --format \"%M %U %S %e\"" +PARTY0_IP="127.0.0.1" +PORT=12345 + +PARTY0_RUN_CMD="${MODEL_BINARY_PATH} r=1 p=${PORT} < ${MODEL_WEIGHT_PATH}" +PARTY1_RUN_CMD="${MODEL_BINARY_PATH} r=2 ip=${PARTY0_IP} p=${PORT} < ${MODEL_INPUT_PATH}" + +PARTY0_DUMP_CMD="> ${MODEL_DIR}/party0_mpc_output.out 2> ${MODEL_DIR}/party0_stats" +PARTY1_DUMP_CMD="> ${MODEL_DIR}/party1_mpc_output.out 2> ${MODEL_DIR}/party1_stats" + + +PARTY0_FINAL_CMD="${TIME_CMD} ${PARTY0_RUN_CMD} ${PARTY0_DUMP_CMD}" +PARTY1_FINAL_CMD="${TIME_CMD} ${PARTY1_RUN_CMD} ${PARTY1_DUMP_CMD}" + +tmux send-keys -t "${SESSION_NAME}:0.0" "${PARTY0_FINAL_CMD}" Enter +tmux send-keys -t "${SESSION_NAME}:0.1" "${PARTY1_FINAL_CMD}" Enter + +PARTY0_FINAL_CMD="clear; ${SCRIPT_DIR}/print_stats_2pc.sh 0 ${MODEL_DIR}" +PARTY1_FINAL_CMD="clear; ${SCRIPT_DIR}/print_stats_2pc.sh 1 ${MODEL_DIR}" + +tmux send-keys -t "${SESSION_NAME}:0.0" "${PARTY0_FINAL_CMD}" Enter +tmux send-keys -t "${SESSION_NAME}:0.1" "${PARTY1_FINAL_CMD}" Enter \ No newline at end of file diff --git a/Athos/CompilerScripts/sample_networks/run_demo_3pc.sh b/Athos/CompilerScripts/sample_networks/run_demo_3pc.sh new file mode 100755 index 00000000..6a4d7226 --- /dev/null +++ b/Athos/CompilerScripts/sample_networks/run_demo_3pc.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +MODEL_DIR=$1 +MODEL_BINARY_PATH=$2 +MODEL_INPUT_PATH=$3 +MODEL_WEIGHT_PATH=$4 + +MODEL_NAME=$(basename ${MODEL_DIR}) +SESSION_NAME=${MODEL_NAME} + +tmux has-session -t "${SESSION_NAME}" > /dev/null 2>&1 +if [ "$?" -eq 0 ]; then + echo "Killing existing tmux ${SESSION_NAME} session" + tmux kill-session -t "${SESSION_NAME}" +fi + +tmux new-session -s "${SESSION_NAME}" -d +tmux split-window -v -t "${SESSION_NAME}:0.0" +tmux split-window -h -t "${SESSION_NAME}:0.1" + +tmux send-keys -t "${SESSION_NAME}:0.0" "clear" Enter +tmux send-keys -t "${SESSION_NAME}:0.1" "clear" Enter +tmux send-keys -t "${SESSION_NAME}:0.2" "clear" Enter + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +PORTHOS_DIR=$(realpath "${SCRIPT_DIR}/../../../Porthos") +ADDRESS_FILE="${PORTHOS_DIR}/files/addresses" +KEYS_DIR="${PORTHOS_DIR}/files/keys" + +TIME_CMD="/usr/bin/time --format \"%M %U %S %e\"" +PARTY0_RUN_CMD="${MODEL_BINARY_PATH} 0 ${ADDRESS_FILE} ${KEYS_DIR} < ${MODEL_INPUT_PATH}" +PARTY1_RUN_CMD="${MODEL_BINARY_PATH} 1 ${ADDRESS_FILE} ${KEYS_DIR} < ${MODEL_WEIGHT_PATH}" +PARTY2_RUN_CMD="${MODEL_BINARY_PATH} 2 ${ADDRESS_FILE} ${KEYS_DIR}" + +PARTY0_DUMP_CMD="> ${MODEL_DIR}/party0_mpc_output.out 2> ${MODEL_DIR}/party0_stats" +PARTY1_DUMP_CMD="> ${MODEL_DIR}/party1_mpc_output.out 2> ${MODEL_DIR}/party1_stats" +PARTY2_DUMP_CMD="> ${MODEL_DIR}/party2_mpc_output.out 2> ${MODEL_DIR}/party2_stats" + +PARTY0_FINAL_CMD="${TIME_CMD} ${PARTY0_RUN_CMD} ${PARTY0_DUMP_CMD}" +PARTY1_FINAL_CMD="${TIME_CMD} ${PARTY1_RUN_CMD} ${PARTY1_DUMP_CMD}" +PARTY2_FINAL_CMD="${TIME_CMD} ${PARTY2_RUN_CMD} ${PARTY2_DUMP_CMD}" + +tmux send-keys -t "${SESSION_NAME}:0.0" "${PARTY0_FINAL_CMD}" Enter +tmux send-keys -t "${SESSION_NAME}:0.1" "${PARTY1_FINAL_CMD}" Enter +tmux send-keys -t "${SESSION_NAME}:0.2" "${PARTY2_FINAL_CMD}" Enter + +PARTY0_FINAL_CMD="clear; ${SCRIPT_DIR}/print_stats_3pc.sh 0 ${MODEL_DIR}" +PARTY1_FINAL_CMD="clear; ${SCRIPT_DIR}/print_stats_3pc.sh 1 ${MODEL_DIR}" +PARTY2_FINAL_CMD="clear; ${SCRIPT_DIR}/print_stats_3pc.sh 2 ${MODEL_DIR}" + +tmux send-keys -t "${SESSION_NAME}:0.0" "${PARTY0_FINAL_CMD}" Enter +tmux send-keys -t "${SESSION_NAME}:0.1" "${PARTY1_FINAL_CMD}" Enter +tmux send-keys -t "${SESSION_NAME}:0.2" "${PARTY2_FINAL_CMD}" Enter \ No newline at end of file diff --git a/Athos/CompilerScripts/sample_networks/run_demo_cpp.sh b/Athos/CompilerScripts/sample_networks/run_demo_cpp.sh new file mode 100755 index 00000000..d52953ce --- /dev/null +++ b/Athos/CompilerScripts/sample_networks/run_demo_cpp.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +MODEL_DIR=$1 +MODEL_BINARY_PATH=$2 +MODEL_INPUT_PATH=$3 +MODEL_WEIGHT_PATH=$4 + +MODEL_NAME=$(basename ${MODEL_DIR}) +SESSION_NAME=${MODEL_NAME} + +tmux has-session -t "${SESSION_NAME}" > /dev/null 2>&1 +if [ "$?" -eq 0 ]; then + echo "Killing existing tmux ${SESSION_NAME} session" + tmux kill-session -t "${SESSION_NAME}" +fi + +tmux new-session -s "${SESSION_NAME}" -d +tmux send-keys -t "${SESSION_NAME}:0.0" "clear" Enter + +TIME_CMD="/usr/bin/time --format \"%M %U %S %e\"" +RUN_CMD="${MODEL_BINARY_PATH} < <(cat ${MODEL_INPUT_PATH} ${MODEL_WEIGHT_PATH})" +DUMP_CMD="> ${MODEL_DIR}/mpc_output.out 2> ${MODEL_DIR}/party0_stats" +FINAL_CMD="${TIME_CMD} ${RUN_CMD} ${DUMP_CMD}" + +tmux send-keys -t "${SESSION_NAME}:0.0" "${FINAL_CMD}" Enter + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +FINAL_CMD="clear; ${SCRIPT_DIR}/print_stats_cpp.sh ${MODEL_DIR}" + +tmux send-keys -t "${SESSION_NAME}:0.0" "${FINAL_CMD}" Enter \ No newline at end of file diff --git a/Athos/HelperScripts/SetupCIFAR10.sh b/Athos/HelperScripts/SetupCIFAR10.sh old mode 100644 new mode 100755 index 2dc3b1c0..be914e19 --- a/Athos/HelperScripts/SetupCIFAR10.sh +++ b/Athos/HelperScripts/SetupCIFAR10.sh @@ -20,7 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -cifar10DownloadLink="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" -axel -a -n 3 -c --output CIFAR10 "$cifar10DownloadLink" -cd CIFAR10 -tar -xvzf cifar-10-python.tar.gz --directory=. +if [ ! -f "CIFAR10/cifar-10-python.tar.gz" ]; then + wget "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" -P ./CIFAR10 + cd CIFAR10 + tar -xvzf cifar-10-python.tar.gz --directory=. +fi \ No newline at end of file diff --git a/Athos/Networks/.gitignore b/Athos/Networks/.gitignore new file mode 100644 index 00000000..a069bbd2 --- /dev/null +++ b/Athos/Networks/.gitignore @@ -0,0 +1,3 @@ +*/party*_stats +*/tf_pred.float +*/tf_pred.time diff --git a/Athos/Networks/ChestXRay/ChestXRay_tf_main.py b/Athos/Networks/ChestXRay/ChestXRay_tf_main.py index 4e30b837..312ff3ae 100644 --- a/Athos/Networks/ChestXRay/ChestXRay_tf_main.py +++ b/Athos/Networks/ChestXRay/ChestXRay_tf_main.py @@ -24,6 +24,18 @@ import cv2, numpy, sys, os, argparse, time import tensorflow as tf + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from tensorflow.python.util import deprecation +deprecation._PRINT_DEPRECATION_WARNINGS = False +try: + from tensorflow.python.util import module_wrapper as deprecation +except ImportError: + from tensorflow.python.util import deprecation_wrapper as deprecation +deprecation._PER_MODULE_WARNING_LIMIT = 0 + + sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) import DumpTFMtData @@ -83,7 +95,13 @@ def parseArgs(): predictions = sess.run(output_tensor, feed_dict=feed_dict) end_time = time.time() print("*************** Done Prediction****************") + duration = end_time - start_time + print("Time taken in inference : ", duration) print(predictions) + with open('tf_pred.float','w+') as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) + with open('tf_pred.time','w') as f: + f.write(str(round(duration, 2))) trainVarsName = [] for node in optimized_graph_def.node: @@ -91,9 +109,10 @@ def parseArgs(): trainVarsName.append(node.name) trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) if args.savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeights(sess, trainVars, 'ChestXRay_weights_{0}.inp'.format(args.scalingFac), args.scalingFac, 'w') + DumpTFMtData.dumpTrainedWeights(sess, trainVars, "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac, 'w') if args.savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'ChestXRay_weights_float.inp', 'w') + DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'model_weights_float.inp', 'w') if args.saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, 'ChestXRay_img_{0}.inp'.format(args.scalingFac), 'ChestXRay_weights_{0}.inp'.format(args.scalingFac), args.scalingFac) - + DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, "model_input_scale_{}.inp".format(args.scalingFac), + "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac) + diff --git a/Athos/Networks/ChestXRay/setup_and_run.sh b/Athos/Networks/ChestXRay/setup_and_run.sh new file mode 100755 index 00000000..5bd52758 --- /dev/null +++ b/Athos/Networks/ChestXRay/setup_and_run.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +if [ -z "$1" ]; then + scale=12 +else + scale=$1 +fi + +filename="chexray_14_weights_712split_epoch_054_val_loss_191.2588.hdf5" +if [ ! -f "PreTrainedModel/KerasModel/${filename}" ]; then + echo "----------------------------------" + echo "Downloading trained ChestXRay Model" + echo "----------------------------------" + curl "https://chestxray.blob.core.windows.net/chestxraytutorial/tutorial_xray/chexray_14_weights_712split_epoch_054_val_loss_191.2588.hdf5" -o PreTrainedModel/KerasModel + if [ ! -f "PreTrainedModel/TFModel/model.pb" ]; then + cd PreTrainedModel + if [ ! -d "keras_to_tensorflow" ]; then + git clone https://github.com/amir-abdi/keras_to_tensorflow + fi + echo -e "Starting keras to TF model conversion....\n" + python3 keras_to_tensorflow/keras_to_tensorflow.py --output_meta_ckpt=True --save_graph_def=True --input_model="KerasModel/${filename}" --output_model="TFModel/model.pb" + fi +fi + +#exit +echo -e "\n\n" +echo "--------------------------------------------------------------------------------" +echo "Running ChestXRay network and dumping computation graph, inputs and model weights" +echo "This will take some time" +echo "--------------------------------------------------------------------------------" +echo -e "\n\n" +python3 ChestXRay_tf_main.py --runPrediction True --scalingFac $scale --saveImgAndWtData True +echo -e "\n\n" \ No newline at end of file diff --git a/Athos/Networks/DenseNet/.gitignore b/Athos/Networks/DenseNet/.gitignore new file mode 100644 index 00000000..62401fc5 --- /dev/null +++ b/Athos/Networks/DenseNet/.gitignore @@ -0,0 +1 @@ +tf-densenet121.tar.gz diff --git a/Athos/Networks/DenseNet/DenseNet_main.py b/Athos/Networks/DenseNet/DenseNet_main.py index 2a1c56be..4531aa5e 100644 --- a/Athos/Networks/DenseNet/DenseNet_main.py +++ b/Athos/Networks/DenseNet/DenseNet_main.py @@ -28,6 +28,16 @@ import tensorflow as tf import _pickle as pickle +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from tensorflow.python.util import deprecation +deprecation._PRINT_DEPRECATION_WARNINGS = False +try: + from tensorflow.python.util import module_wrapper as deprecation +except ImportError: + from tensorflow.python.util import deprecation_wrapper as deprecation +deprecation._PER_MODULE_WARNING_LIMIT = 0 + import nets_factory sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) import DumpTFMtData @@ -98,8 +108,14 @@ def parseArgs(): predictions = sess.run(output_tensor, feed_dict=feed_dict) end_time = time.time() print("*************** Done Prediction****************") + duration = end_time - start_time + print("Time taken in inference : ", duration) + with open('tf_pred.float','w+') as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) + with open('tf_pred.time','w') as f: + f.write(str(round(duration, 2))) - print(predictions) + print("Prediction = ", predictions) trainVarsName = [] for node in optimized_graph_def.node: @@ -107,9 +123,10 @@ def parseArgs(): trainVarsName.append(node.name) trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) if args.savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeightsInt(sess, trainVars, 'DenseNet_weights.inp', args.scalingFac, 'w') + DumpTFMtData.dumpTrainedWeightsInt(sess, trainVars, "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac, 'w') if args.savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'DenseNet_weights_float.inp', 'w') + DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'model_weights_float.inp', 'w') if args.saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, 'DenseNet_img.inp', 'DenseNet_weights.inp', args.scalingFac) + DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, "model_input_scale_{}.inp".format(args.scalingFac), + "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac) diff --git a/Athos/Networks/DenseNet/setup_and_run.sh b/Athos/Networks/DenseNet/setup_and_run.sh new file mode 100755 index 00000000..0ae2784b --- /dev/null +++ b/Athos/Networks/DenseNet/setup_and_run.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +if [ -z "$1" ]; then + scale=12 +else + scale=$1 +fi + +filename="tf-densenet121.tar.gz" +if [ ! -f "PreTrainedModel/${filename}" ]; then + echo "----------------------------------" + echo "Downloading trained DenseNet Model" + echo "----------------------------------" + fileid=0B_fUSpodN0t0eW1sVk1aeWREaDA + curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null + curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o PreTrainedModel/${filename} + rm cookie +fi +if [[ ! -f "PreTrainedModel/tf-densenet121.ckpt.data-00000-of-00001" || ! -f "PreTrainedModel/tf-densenet121.ckpt.index" || ! -f "PreTrainedModel/tf-densenet121.ckpt.meta" ]]; then + cd PreTrainedModel + tar -xvzf ${filename} + cd - +fi +exit +echo -e "\n\n" +echo "--------------------------------------------------------------------------------" +echo "Running DenseNet network and dumping computation graph, inputs and model weights" +echo "This will take some time" +echo "--------------------------------------------------------------------------------" +echo -e "\n\n" +python3 DenseNet_main.py --runPrediction True --scalingFac $scale --saveImgAndWtData True +echo -e "\n\n" \ No newline at end of file diff --git a/Athos/Networks/ResNet/ResNet_main.py b/Athos/Networks/ResNet/ResNet_main.py index 7b9ed844..cd1d9d6b 100644 --- a/Athos/Networks/ResNet/ResNet_main.py +++ b/Athos/Networks/ResNet/ResNet_main.py @@ -33,6 +33,16 @@ import _pickle as pickle import Resnet_Model +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from tensorflow.python.util import deprecation +deprecation._PRINT_DEPRECATION_WARNINGS = False +try: + from tensorflow.python.util import module_wrapper as deprecation +except ImportError: + from tensorflow.python.util import deprecation_wrapper as deprecation +deprecation._PER_MODULE_WARNING_LIMIT = 0 + sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) import DumpTFMtData @@ -162,11 +172,11 @@ def infer(savePreTrainedWeightsInt, savePreTrainedWeightsFloat, scalingFac, runP end_time = time.time() print("*************** Done Prediction****************") duration = end_time - start_time - print("Time taken in prediction : ", duration) - with open('ResNet_tf_pred.float','w+') as f: + print("Time taken in inference : ", duration) + with open('tf_pred.float','w+') as f: f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) - with open('ResNet_tf_pred.time','w') as f: - f.write(str(round(duration, 2))) + with open('tf_pred.time','w') as f: + f.write(str(round(duration, 2))) trainVarsName = [] for node in optimized_graph_def.node: @@ -174,11 +184,12 @@ def infer(savePreTrainedWeightsInt, savePreTrainedWeightsFloat, scalingFac, runP trainVarsName.append(node.name) trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) if savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeights(sess, trainVars, 'ResNet_weights.inp', scalingFac, 'w') + DumpTFMtData.dumpTrainedWeights(sess, trainVars, "model_weights_scale_{}.inp".format(scalingFac), scalingFac, 'w') if savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'ResNet_weights_float.inp', 'w') + DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'model_weights_float.inp', 'w') if saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, 'ResNet_img.inp', 'ResNet_weights.inp', scalingFac) + DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, "model_input_scale_{}.inp".format(scalingFac), + "model_weights_scale_{}.inp".format(scalingFac), scalingFac) return predictions def parseArgs(): @@ -201,8 +212,8 @@ def main(): args.scalingFac, args.runPrediction, args.saveImgAndWtData) - print(pred) + print("Prediction = ", pred) return pred if __name__=='__main__': - pred = main() + pred = main() diff --git a/Athos/Networks/ResNet/setup_and_run.sh b/Athos/Networks/ResNet/setup_and_run.sh new file mode 100755 index 00000000..d0a04851 --- /dev/null +++ b/Athos/Networks/ResNet/setup_and_run.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +if [ -z "$1" ]; then + scale=12 +else + scale=$1 +fi +filename=resnet_v2_fp32_savedmodel_NHWC.tar.gz +if [ ! -f "PreTrainedModel/${filename}" ]; then + echo "--------------------------------" + echo "Downloading trained ResNet Model" + echo "--------------------------------" + wget "http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz" -P ./PreTrainedModel +fi + +if [[ ! -d "PreTrainedModel/resnet_v2_fp32_savedmodel_NHWC" ]]; then + cd PreTrainedModel + tar -xvzf ${filename} + cd - +fi +exit +echo -e "\n\n" +echo "--------------------------------------------------------------------------------" +echo "Running ResNet network and dumping computation graph, inputs and model weights" +echo "This will take some time" +echo "--------------------------------------------------------------------------------" +echo -e "\n\n" +python3 ResNet_main.py --runPrediction True --scalingFac $scale --saveImgAndWtData True +echo -e "\n\n" \ No newline at end of file diff --git a/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py b/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py index f79fbb67..8e59916a 100644 --- a/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py +++ b/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py @@ -38,6 +38,18 @@ import numpy import matplotlib import tensorflow as tf + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from tensorflow.python.util import deprecation +deprecation._PRINT_DEPRECATION_WARNINGS = False +try: + from tensorflow.python.util import module_wrapper as deprecation +except ImportError: + from tensorflow.python.util import deprecation_wrapper as deprecation +deprecation._PER_MODULE_WARNING_LIMIT = 0 + + sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) import DumpTFMtData from argparse import ArgumentParser @@ -487,6 +499,10 @@ def infer(sqn, sess, images, labels, restoreModelPath, findAccOrArgMaxOrPredVal= print("Time taken in prediction : ", duration) print("Inference result = ", predictions) + with open('tf_pred.float','w+') as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) + with open('tf_pred.time','w') as f: + f.write(str(round(duration, 2))) return predictions def getTrainedWeightsStrForm(sess, evalTensors, scalingFac): @@ -600,8 +616,8 @@ def main(): findAndSaveCorrectTestImg(pred, testing_features, testing_labels, './testPred/CorrectImg/', './testPred/IncorrectImg/', './testPred/TestInputs/', sess, sqn, scalingFac) if (inp == 'savegraphAndDataBatch' or inp=='testSingleTestInpAndSaveData'): - imgFileName = 'SqNet_CIFAR_img.inp' - weightsFileName = 'SqNet_CIFAR_weights.inp' + imgFileName = "model_input_scale_{}.inp".format(scalingFac) + weightsFileName = "model_weights_scale_{}.inp".format(scalingFac) for ii,curFeature in enumerate(testing_features): if ii == 0 : DumpTFMtData.dumpImageDataInt(curFeature, imgFileName, scalingFac, 'w') @@ -611,4 +627,4 @@ def main(): if __name__ == '__main__': main() - + diff --git a/Athos/Networks/SqueezeNetCIFAR10/setup_and_run.sh b/Athos/Networks/SqueezeNetCIFAR10/setup_and_run.sh new file mode 100755 index 00000000..d84b9bc4 --- /dev/null +++ b/Athos/Networks/SqueezeNetCIFAR10/setup_and_run.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +if [ -z "$1" ]; then + scale=12 +else + scale=$1 +fi + +if [ ! -f "PreProcessedImages/preprocess_batch_1.p" ] ; then + echo -e "\n\n" + echo "--------------------------------------------------------------------------------" + echo "One-time set up of CIFAR10 dataset for SqueezeNet Training" + echo "This will take some time" + echo "--------------------------------------------------------------------------------" + echo -e "\n\n" + cd "../../HelperScripts" + ./SetupCIFAR10.sh + cd - + python3 Util.py +fi + +if [ ! -f "TrainedModel/model.meta" ] ; then + echo -e "\n\n" + echo "--------------------------------------------------------------------------------" + echo "Training SqueezeNet network for 1 epoch" + echo "This will take some time" + echo "--------------------------------------------------------------------------------" + echo -e "\n\n" + python3 Squeezenet_model.py train +fi + +echo -e "\n\n" +echo "--------------------------------------------------------------------------------" +echo "Running SqueezeNetCIFAR10 network and dumping computation graph, inputs and model weights" +echo "This will take some time" +echo "--------------------------------------------------------------------------------" +echo -e "\n\n" +python3 Squeezenet_model.py savegraph +python3 Squeezenet_model.py testSingleTestInpAndSaveData 1 1 +echo -e "\n\n" \ No newline at end of file diff --git a/Athos/Networks/SqueezeNetImgNet/setup_and_run.sh b/Athos/Networks/SqueezeNetImgNet/setup_and_run.sh new file mode 100755 index 00000000..f141763a --- /dev/null +++ b/Athos/Networks/SqueezeNetImgNet/setup_and_run.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +if [ -z "$1" ]; then + scale=12 +else + scale=$1 +fi +filename=sqz_full.mat +if [ ! -f "PreTrainedModel/${filename}" ]; then + echo "--------------------------------" + echo "Downloading trained SqueezeNet Model" + echo "--------------------------------" + wget "https://github.com/avoroshilov/tf-squeezenet/raw/master/sqz_full.mat" -P ./PreTrainedModel +fi + +echo -e "\n\n" +echo "--------------------------------------------------------------------------------" +echo "Running SqueezeNetImgNet network and dumping computation graph, inputs and model weights" +echo "This will take some time" +echo "--------------------------------------------------------------------------------" +echo -e "\n\n" +python3 squeezenet_main.py --in ./SampleImages/n02109961_36.JPEG --saveTFMetadata True +python3 squeezenet_main.py --in ./SampleImages/n02109961_36.JPEG --scalingFac $scale --saveImgAndWtData True +echo -e "\n\n" \ No newline at end of file diff --git a/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py b/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py index 57133553..8b6cbc9b 100644 --- a/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py +++ b/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py @@ -9,6 +9,16 @@ from PIL import Image from argparse import ArgumentParser +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +from tensorflow.python.util import deprecation +deprecation._PRINT_DEPRECATION_WARNINGS = False +try: + from tensorflow.python.util import module_wrapper as deprecation +except ImportError: + from tensorflow.python.util import deprecation_wrapper as deprecation +deprecation._PER_MODULE_WARNING_LIMIT = 0 + sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) import DumpTFMtData @@ -258,18 +268,27 @@ def main(): optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) else: # Classifying + print("*************** Starting Prediction****************") + start_time = time.time() sqz_class = final_class.eval(feed_dict)[0][0][0] + end_time = time.time() + print("*************** Done Prediction****************") + duration = end_time - start_time + print("Time taken in inference : ", duration) + with open('tf_pred.float','w+') as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(sqz_class)) + with open('tf_pred.time','w') as f: + f.write(str(round(duration, 2))) - print(sqz_class) # Outputting result print("\nclass: [%d] '%s'" % (sqz_class, classes[sqz_class])) if options.savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeightsInt(sess, all_weights, 'SqNet_weights.inp', options.scalingFac, 'w', alreadyEvaluated=True) + DumpTFMtData.dumpTrainedWeightsInt(sess, all_weights, "model_weights_scale_{}.inp".format(options.scalingFac), options.scalingFac, 'w', alreadyEvaluated=True) if options.savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, all_weights, 'SqNet_weights_float.inp', 'w', alreadyEvaluated=True) + DumpTFMtData.dumpTrainedWeightsFloat(sess, all_weights, 'model_weights_float.inp', 'w', alreadyEvaluated=True) if options.saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, imageData, all_weights, 'SqNet_img.inp', 'SqNet_weights.inp', options.scalingFac, alreadyEvaluated=True) + DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, imageData, all_weights, "model_input_scale_{}.inp".format(options.scalingFac), "model_weights_scale_{}.inp".format(options.scalingFac), options.scalingFac, alreadyEvaluated=True) if __name__ == '__main__': main() \ No newline at end of file diff --git a/Athos/Networks/sample_network.config b/Athos/Networks/sample_network.config new file mode 100644 index 00000000..266eb867 --- /dev/null +++ b/Athos/Networks/sample_network.config @@ -0,0 +1,5 @@ +{ + "network_name":"ResNet", + "target":"PORTHOS", + "run_in_tmux": true +} diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index 258c95e7..d81d536f 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -118,7 +118,7 @@ def run(self): print("Relu-maxpool optimization done.") if not(Util.Config.disableLivenessOpti): - print("Performing Garbage colelction...") + print("Performing Garbage collection...") mtdAST = MtdAST() GC = GarbageCollector.GarbageCollector(ast) GC.run([mtdAST]) diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 854cb80c..64e84d2d 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -141,7 +141,10 @@ def simplifyGraph(graph): def process_tf_graph(filename): sys.setrecursionlimit(10000) - folderName = os.path.dirname(filename) + if os.path.isfile(filename): + folderName = os.path.dirname(filename) + elif os.path.isdir(filename): + folderName = filename graphFileName = os.path.join(folderName, 'graphDef.mtdata') graph = Graph.Graph() with open(graphFileName) as file: From 921fb4dfd8e65bbf285dd44dafd177378807d0cf Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 20 Jan 2021 10:51:03 +0530 Subject: [PATCH 37/72] Add script to install dependencies and compile everything --- README.md | 10 +++++- setup_env_and_build.sh | 69 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 setup_env_and_build.sh diff --git a/README.md b/README.md index 6c8e0a28..79cae3c9 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,15 @@ Each one of the above is independent and usable in their own right and more info With these components in place, we are able to run for the first time secure inference on the [ImageNet dataset]([http://www.image-net.org) with the pre-trained models of the following deep neural nets: ResNet-50, DenseNet-121 and SqueezeNet for ImageNet. -For setup instructions, please refer to each of the components' readme. We plan to release a docker version of the system as well which will make the system easier to setup. +For setup instructions, please refer to each of the components' readme. + +Alternatively you can use the **setup_env_and_build.sh** script. It installs dependencies and builds each component. It also creates a virtual environment in a *mpc_venv* folder with all the required packages. + +Please do ``source mpc_venv/bin/activate`` before using the toolchain. + +We plan to release a docker version of the system as well which will make the system easier to setup. + +We plan to release a docker version of the system as well which will make the system easier to setup. ## Wiki Wiki section of this repository provides coding practices and examples to get started with EzPC. diff --git a/setup_env_and_build.sh b/setup_env_and_build.sh new file mode 100644 index 00000000..afa8839a --- /dev/null +++ b/setup_env_and_build.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# Authors: Pratik Bhatu. + +# Copyright: +# Copyright (c) 2021 Microsoft Research +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +sudo add-apt-repository ppa:deadsnakes/ppa +sudo apt update +sudo apt install -y build-essential make cmake libgmp-dev libglib2.0-dev libssl-dev libboost-all-dev m4 python3.7 + +sudo apt install unzip bubblewrap +sh <(curl -sL https://raw.githubusercontent.com/ocaml/opam/master/shell/install.sh) +# environment setup +opam init +eval `opam env` +# install given version of the compiler +opam switch create 4.10.0 +eval `opam env` +# check if we got what we wanted +which ocaml +ocaml -version +opam install -y Stdint +opam install -y menhir +opam install -y ocamlbuild +opam install -y ocamlfind + +#Virtual environment +sudo apt install -y python3.7-venv +python3.7 -m venv mpc_venv +source mpc_venv/bin/activate +pip install -U pip +pip install tensorflow==1.15.0 keras==2.3.0 scipy==1.1.0 matplotlib +pip install pytest pytest-cov + +# Now we build all the components. +ROOT="$(pwd)" +#Build Ezpc +cd EzPC/EzPC +eval `opam env` +make +#Build Porthos +cd $ROOT/Porthos +./setup-eigen.sh +mkdir -p src/build +cd src/build +cmake ../ +make -j +#Build SCI +cd $ROOT/SCI +mkdir -p build +cd build +cmake -DBUILD_NETWORKS=ON ../ +make -j \ No newline at end of file From f423c5cbf46195c177438e91399a7f65f3cea77e Mon Sep 17 00:00:00 2001 From: Pratik Bhatu Date: Wed, 20 Jan 2021 10:52:04 +0530 Subject: [PATCH 38/72] Update README.md --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 79cae3c9..0b7d35d8 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,6 @@ Please do ``source mpc_venv/bin/activate`` before using the toolchain. We plan to release a docker version of the system as well which will make the system easier to setup. -We plan to release a docker version of the system as well which will make the system easier to setup. - ## Wiki Wiki section of this repository provides coding practices and examples to get started with EzPC. From 0c021e03229ea00e5be4413d7e9a228a1e242790 Mon Sep 17 00:00:00 2001 From: Pratik Bhatu Date: Wed, 20 Jan 2021 10:53:36 +0530 Subject: [PATCH 39/72] add setup section in readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0b7d35d8..92dff75f 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ Each one of the above is independent and usable in their own right and more info With these components in place, we are able to run for the first time secure inference on the [ImageNet dataset]([http://www.image-net.org) with the pre-trained models of the following deep neural nets: ResNet-50, DenseNet-121 and SqueezeNet for ImageNet. +## Setup For setup instructions, please refer to each of the components' readme. Alternatively you can use the **setup_env_and_build.sh** script. It installs dependencies and builds each component. It also creates a virtual environment in a *mpc_venv* folder with all the required packages. From 0321c081235347a1f225e53d29f8eab295f84e60 Mon Sep 17 00:00:00 2001 From: Pratik Bhatu Date: Wed, 20 Jan 2021 11:56:09 +0530 Subject: [PATCH 40/72] [README] Instructions to run the networks using scripts. --- Athos/README.md | 100 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 82 insertions(+), 18 deletions(-) diff --git a/Athos/README.md b/Athos/README.md index c3e0e66b..aba9b63b 100644 --- a/Athos/README.md +++ b/Athos/README.md @@ -3,37 +3,89 @@ This folder contains the code for Athos - an end-to-end compiler from TensorFlow # Requirements/Setup Below we list some of the packages required to get Athos up and running. This is a non-exhaustive list and we only mention the packages on which the system has been tested so far. -- axel: download manager used to download pre-trained models `sudo apt-get install axel`. -- python3.6 -- TensorFlow 1.11 +- python3.7 +- TensorFlow 1.15 - Numpy Athos also makes use of the EzPC compiler internally (please check `../EzPC/README.md` for corresponding dependencies). -# Directory structure -The codebase is organized as follows: -- `HelperScripts`: This folder contains numerous helper scripts which help from automated setup of ImageNet/CIFAR10 dataset to finding accuracy from output files. Please refer to each of the scripts for further instructions on how to use them. -- `Networks`: This folder contains the code in TensorFlow of the various benchmarks/networks we run in CrypTFlow. Among other networks, it includes code for ResNet, DenseNet, SqueezeNet for ImageNet dataset, SqueezeNet for CIFAR10 dataset, Lenet, Logistic Regression, and a chest x-ray demo network. -- `SeeDot`: This contains code for SeeDot, a high-level intermediate language on which Athos performs various optimizations before compiling to MPC protocols. -- `TFCompiler`: This contains python modules which are called from the TensorFlow code for the dumping of TensorFlow metadata (required by Athos for compilation to MPC protocols). -- `TFEzPCLibrary`: This contains library code written in EzPC for the TensorFlow nodes required during compilation. -- `CompileTF.sh`: The Athos compilation script. Try `./CompileTF.sh --help` for options. -- `CompileTFGraph.sh`: The Athos compilation script for protobuf models. Try `./CompileTFGraph.sh --help` for options. -- `Paths.config`: This can be used to override the default folders for EzPC and Porthos. -- `CompilerScripts`: This folder contains scripts used for processing and compiling dumped models. - # Usage -Here we provide an example on how to use Athos to compile TensorFlow based ResNet-50 code to Porthos semi-honest 3PC protocol and subsequently run it. The relevant TensorFlow code for ResNet-50 can be found in `./Networks/ResNet/ResNet_main.py`. +Please source the virtual environment in `mpc_venv` if you used the `setup_env_and_build.sh` script to setup and build. +## Automatic +The `CompileSampleNetworks.py` script can compile and optionally run models in the Networks directory like ResNet-50, DenseNet, SqueezeNet, etc.. +To compile and run ResNet with the Porthos semi-honest 3PC protocol we do: + +```python CompileSampleNetworks.py --config Networks/sample_network.config``` + +The script takes a config file as input. The contents of the config are: +``` +{ + "network_name":"ResNet", + "target":"PORTHOS", + "run_in_tmux": true +} +``` +- *network_name*: Can be any network in the Networks directory. +- *target*: This is the secure protocol the model will run in. The possible values are: + - **PORTHOS**: The semi-honest 3PC protocol. + - **PORTHOS2PC**: The semi-honest 2PC protocol in SCI. + - **CPP**: A non-secure debug backend which outputs plain C++ to test for correctness. +- *run_in_tmux*: If true, the script spawns a tmux session to run the network. There is a terminal pane for each party. +You can modify the config file according to which network and backend you want to compile for. See ```python CompileSampleNetworks.py --help``` for more information about the parameters of the config file. + +**Output:** +After connecting to the session with ```tmux a -t ResNet``` you should see output similar to the following after the computation is complete. Number will vary based on the specs of your machine. + +| | | +|-|-| +|-------------------------------------------------------
**ResNet results [Client]**
-------------------------------------------------------
Model outputs:
MPC PORTHOS (3PC) output: 249
Tensorflow output: 249

Execution summary for Client:
Communication for execution, P0: 2377.13MB (sent) 1825.32MB (recv)
Peak Memory Usage: 432156 KB (.41GB)
Total time taken: 70.50 seconds
Total work time: 68.57 seconds (97.27%)
Time spent waiting: 1.91 seconds (2.72%)
Time taken by tensorflow: 0.31 seconds | +|-------------------------------------------------------
**ResNet results [Server]**
-------------------------------------------------------
Execution summary for Server:
Communication for execution, P1: 2377.13MB (sent) 2155.96MB (recv)
Peak Memory Usage: 432428 KB (.41GB)
Total time taken: 70.60 seconds
Total work time: 68.77 seconds (97.42%)
Time spent waiting: 1.81 seconds (2.57%) |-------------------------------------------------------
**ResNet results [Helper]**
-------------------------------------------------------
Execution summary for Helper:
Communication for execution, P2: 2113.81MB (sent) 2886.8MB (recv)
Peak Memory Usage: 427508 KB (.40GB)
Total time taken: 70.53 seconds
Total work time: 64.06 seconds (90.83%)
Time spent waiting: 6.46 seconds (9.16%) + +#### Running manually (non-tmux-mode) + +If run_in_tmux is false and you want to run the network manually, you will need the following files that are generated by the script in the `Networks/ResNet` directory: +- *ResNet_PORTHOS.out*: binary of the compiled network. +- *model_input_scale_12.inp*: image input to the model. +- *model_weights_scale_12.inp*: model weights. + +Running it manually will not give you a neat summary as shown above but will print the computed output on the client terminal. + +To run the network in **3PC mode (PORTHOS)**, open 3 terminals and do the following for each party: + +- Party 0 [Client]: + + ``` ./Networks/ResNet/ResNet_PORTHOS.out 0 ../Porthos/files/addresses ../Porthos/files/keys < model_input_scale_12.inp" ``` +- Party 1 [Server]: + + ``` ./Networks/ResNet/ResNet_PORTHOS.out 1 ../Porthos/files/addresses ../Porthos/files/keys < model_weights_scale_12.inp" ``` +- Party 2 [Helper]: + + ``` ./Networks/ResNet/ResNet_PORTHOS.out 1 ../Porthos/files/addresses ../Porthos/files/keys" ``` + +To run the network in **2PC mode (SCI)**, open 2 terminals and do the following for each party: + +- Party 0 [Server]: + + ``` ./Networks/ResNet/ResNet_PORTHOS2PC_OT.out r=1 p=12345 < model_weights_scale_12.inp" ``` +- Party 1 [Client]: + + ``` ./Networks/ResNet/ResNet_PORTHOS2PC_OT.out r=2 ip=127.0.0.1 p=12345 < model_input_scale_12.inp" ``` + +To run the network in **CPP mode (1PC-debug-non-secure)**, open a terminals and do the following: +- ``` ./Networks/ResNet/ResNet_CPP.out < <(cat model_input_scale_12.inp model_weights_scale_12.inp) ``` + +## Manually +To better understand what the `CompileSampleNetworks.py` script is doing under the hood, we can step through each step manually. Here we provide an example on how to use Athos to compile TensorFlow based ResNet-50 code to Porthos semi-honest 3PC protocol and subsequently run it. The relevant TensorFlow code for ResNet-50 can be found in `./Networks/ResNet/ResNet_main.py`. - Refer to `./Networks/ResNet/README.md` for instructions on how to download and extract the ResNet-50 pretrained model from the official TensorFlow model page. - `cd ./Networks/ResNet && python3 ResNet_main.py --runPrediction True --scalingFac 12 --saveImgAndWtData True && cd -` Runs the ResNet-50 code written in TensorFlow to dump the metadata which is required by Athos for further compilation. -This command execution should result in 2 files which will be used for further compilation - `./Networks/ResNet/graphDef.mtdata` and `./Networks/ResNet/sizeInfo.mtdata`. In addition, the image and the model are also saved in fixed-point format, which can be later input into the compiled code - `./Networks/ResNet/ResNet_img.inp` which contains the image and `./Networks/ResNet/ResNet_weights.inp` which contains the model. +This command execution should result in 2 files which will be used for further compilation - `./Networks/ResNet/graphDef.mtdata` and `./Networks/ResNet/sizeInfo.mtdata`. In addition, the image and the model are also saved in fixed-point format, which can be later input into the compiled code - `./Networks/ResNet/model_input_scale_12.inp` which contains the image and `./Networks/ResNet/model_weights_scale_12.inp` which contains the model weights. - The next step is to perform the compilation itself. The compilation script internally makes use of the `ezpc` executable. So, before continuing please ensure that you have built `ezpc` (please check the `../EzPC/README.md` for further instructions on that). - Once EzPC has been built, run this to compile the model to Porthos - `./CompileTF.sh -b 64 -s 12 -t PORTHOS -f ./Networks/ResNet/ResNet_main.py`. This should result in creation of the file - `./Networks/ResNet/ResNet_main_64_porthos0.cpp`. - `cp ./Networks/ResNet/ResNet_main_64_porthos0.cpp ../Porthos/src/main.cpp` Copy the compiled file to Porthos. - `cd ../Porthos && make clean && make -j` -- Finally run the 3 parties. Open 3 terminals and run the following in each for the 3 parties. +- Finally run the 3 parties. Go to the porthos directory and open 3 terminals and run the following in each for the 3 parties. `./party0.sh < ../Athos/Networks/ResNet/ResNet_img.inp` , `./party1.sh < ../Athos/Networks/ResNet/ResNet_weights.inp` , `./party2.sh`. @@ -41,6 +93,18 @@ Once the above runs, the final answer for prediction should appear in the output Instructions on how to run the particular TensorFlow model in `./Networks` can vary. Please refer to the appropriate readme in each model folder to get more insights. But once that is done, the further compilation commands are the same. +# Directory structure +The codebase is organized as follows: +- `HelperScripts`: This folder contains numerous helper scripts which help from automated setup of ImageNet/CIFAR10 dataset to finding accuracy from output files. Please refer to each of the scripts for further instructions on how to use them. +- `Networks`: This folder contains the code in TensorFlow of the various benchmarks/networks we run in CrypTFlow. Among other networks, it includes code for ResNet, DenseNet, SqueezeNet for ImageNet dataset, SqueezeNet for CIFAR10 dataset, Lenet, Logistic Regression, and a chest x-ray demo network. +- `SeeDot`: This contains code for SeeDot, a high-level intermediate language on which Athos performs various optimizations before compiling to MPC protocols. +- `TFCompiler`: This contains python modules which are called from the TensorFlow code for the dumping of TensorFlow metadata (required by Athos for compilation to MPC protocols). +- `TFEzPCLibrary`: This contains library code written in EzPC for the TensorFlow nodes required during compilation. +- `CompileTF.sh`: The Athos compilation script. Try `./CompileTF.sh --help` for options. +- `CompileTFGraph.py`: The Athos compilation script for protobuf models. Try `python CompileTFGraph.py --help` for options. +- `Paths.config`: This can be used to override the default folders for EzPC and Porthos. +- `CompilerScripts`: This folder contains scripts used for processing and compiling dumped models. + # Preprocessing images and running inference on ImageNet validation dataset - First setup the ImageNet validation dataset using the script provided in `./HelperScripts/Prepare_ImageNet_Val.sh`. This sets up the ImageNet validation dataset in the folder - `./HelperScripts/ImageNet_ValData`. - Each of the network folders - `./Networks/ResNet`, `./Networks/DenseNet` and `./Networks/SqueezeNetImgNet` is provided with these folders: From 5763338ec01f08eed0f50078edf9a547732a2be8 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 20 Jan 2021 17:25:31 +0530 Subject: [PATCH 41/72] Change default save_weights to true --- Athos/CompileTFGraph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index e9379f05..e8e3fa9a 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -57,7 +57,7 @@ def parse_args(): //--------------------------- Optional options --------------------------- "scale":10, // Scaling factor to compile for. DEFAULT=12. "bitlength":64, // Bit length to compile for. DEFAULT=12. - "save_weights" : true, // Save model scaled weights in fixed point. DEFAULT=false. + "save_weights" : true, // Save model scaled weights in fixed point. DEFAULT=true. "input_tensors":{ // Name and shape of the input tensors "actual_input_1":"224,244,3", // for the model. Not required if the @@ -89,7 +89,7 @@ def generate_code(params, debug=False): scale = 12 if params["scale"] is None else params["scale"] bitlength = 64 if params["bitlength"] is None else params["bitlength"] target = params["target"] - save_weights = False if params["save_weights"] is None else params["save_weights"] + save_weights = True if params["save_weights"] is None else params["save_weights"] disable_all_hlil_opts = ( False if params["disable_all_hlil_opts"] is None From 80eb0d3f1dcdb9680e789584e7eed1d695a75d30 Mon Sep 17 00:00:00 2001 From: Pratik Bhatu Date: Wed, 20 Jan 2021 23:43:09 +0530 Subject: [PATCH 42/72] Add instructions to compile any tf model --- Athos/README.md | 51 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/Athos/README.md b/Athos/README.md index aba9b63b..387b0c96 100644 --- a/Athos/README.md +++ b/Athos/README.md @@ -1,17 +1,31 @@ +- [Introduction](#introduction) +- [Requirements/Setup](#requirementssetup) +- [Usage](#usage) + * [Compiling and Running Models in Networks Directory Automatically](#compiling-and-running-models-in-networks-directory-automatically) + - [Running manually (non-tmux-mode)](#running-manually-non-tmux-mode) + * [Compiling and Running Models in Networks Directory Manually](#compiling-and-running-models-in-networks-directory-manually) + * [Compiling your own tensorflow model](#compiling-your-own-tensorflow-model) +- [Directory structure](#directory-structure) +- [Preprocessing images and running inference on ImageNet validation dataset](#preprocessing-images-and-running-inference-on-imagenet-validation-dataset) + # Introduction This folder contains the code for Athos - an end-to-end compiler from TensorFlow to a variety of secure computation protocols. # Requirements/Setup -Below we list some of the packages required to get Athos up and running. This is a non-exhaustive list and we only mention the packages on which the system has been tested so far. +If you used the `setup_env_and_build.sh` script the below would already have been installed in the `mpc_venv` environment. We require the below packages to run Athos. - python3.7 - TensorFlow 1.15 - Numpy +- pytest, pytest-cov (For running tests) Athos also makes use of the EzPC compiler internally (please check `../EzPC/README.md` for corresponding dependencies). # Usage Please source the virtual environment in `mpc_venv` if you used the `setup_env_and_build.sh` script to setup and build. -## Automatic + +`source mpc_venv/bin/activate` + +## Compiling and Running Models in Networks Directory Automatically The `CompileSampleNetworks.py` script can compile and optionally run models in the Networks directory like ResNet-50, DenseNet, SqueezeNet, etc.. To compile and run ResNet with the Porthos semi-honest 3PC protocol we do: @@ -34,7 +48,7 @@ The script takes a config file as input. The contents of the config are: You can modify the config file according to which network and backend you want to compile for. See ```python CompileSampleNetworks.py --help``` for more information about the parameters of the config file. **Output:** -After connecting to the session with ```tmux a -t ResNet``` you should see output similar to the following after the computation is complete. Number will vary based on the specs of your machine. +After connecting to the session with ```tmux a -t ResNet``` you should see output similar to the following after the computation is complete. Numbers will vary based on the specs of your machine. | | | |-|-| @@ -44,8 +58,8 @@ After connecting to the session with ```tmux a -t ResNet``` you should see outpu #### Running manually (non-tmux-mode) If run_in_tmux is false and you want to run the network manually, you will need the following files that are generated by the script in the `Networks/ResNet` directory: -- *ResNet_PORTHOS.out*: binary of the compiled network. -- *model_input_scale_12.inp*: image input to the model. +- *ResNet_PORTHOS.out*: binary of the compiled network. +- *model_input_scale_12.inp*: image input to the model. - *model_weights_scale_12.inp*: model weights. Running it manually will not give you a neat summary as shown above but will print the computed output on the client terminal. @@ -71,10 +85,10 @@ To run the network in **2PC mode (SCI)**, open 2 terminals and do the following ``` ./Networks/ResNet/ResNet_PORTHOS2PC_OT.out r=2 ip=127.0.0.1 p=12345 < model_input_scale_12.inp" ``` -To run the network in **CPP mode (1PC-debug-non-secure)**, open a terminals and do the following: +To run the network in **CPP mode (1PC-debug-non-secure)**, open a terminal and do the following: - ``` ./Networks/ResNet/ResNet_CPP.out < <(cat model_input_scale_12.inp model_weights_scale_12.inp) ``` -## Manually +## Compiling and Running Models in Networks Directory Manually To better understand what the `CompileSampleNetworks.py` script is doing under the hood, we can step through each step manually. Here we provide an example on how to use Athos to compile TensorFlow based ResNet-50 code to Porthos semi-honest 3PC protocol and subsequently run it. The relevant TensorFlow code for ResNet-50 can be found in `./Networks/ResNet/ResNet_main.py`. - Refer to `./Networks/ResNet/README.md` for instructions on how to download and extract the ResNet-50 pretrained model from the official TensorFlow model page. - `cd ./Networks/ResNet && python3 ResNet_main.py --runPrediction True --scalingFac 12 --saveImgAndWtData True && cd -` @@ -93,6 +107,29 @@ Once the above runs, the final answer for prediction should appear in the output Instructions on how to run the particular TensorFlow model in `./Networks` can vary. Please refer to the appropriate readme in each model folder to get more insights. But once that is done, the further compilation commands are the same. +## Compiling your own tensorflow model +The `CompileTFGraph.py` script can compile tensorflow models (v1.15). You can dump your tensorflow model as a frozen graph. Run [convert_variables_to_constants](https://www.tensorflow.org/api_docs/python/tf/compat/v1/graph_util/convert_variables_to_constants) on your model graph and then dump the output graph_def as a protobuf (see `dump_graph_def_pb` in `CompilerScripts/tf_graph_io.py`). Once you have the model.pb file simply do: +``` +python CompileTFGraph.py --config model.config +``` +See `python CompileTFGraph.py --help` for additional details on the model.config parameters. A sample config could be: +``` +{ + "model_name": "model.pb", + "output_tensors": [ "output1" ], + "target": "PORTHOS" +} +``` +You will see the output messages of the compiler and a `model_PORTHOS.out` binary will be generated. You will also see this in the output: +``` +Model compilation done. +Dumping model weights in model_input_weights_fixedpt_scale_12.inp . +These are to be used as input for party which owns the model. +``` +Use the `model_input_weights_fixedpt_scale_12.inp` file as input for the server party. For model input you can create random input using `CompilerScripts/create_tf_input.py` or pass your input as a numpy array to the `dumpImageDataInt` function in `TFCompiler/DumpTFMtData.py`. For both scripts you need to pass the scaling factor for conversion of floating point to fixed point (we use 12 for ResNet). + + + # Directory structure The codebase is organized as follows: - `HelperScripts`: This folder contains numerous helper scripts which help from automated setup of ImageNet/CIFAR10 dataset to finding accuracy from output files. Please refer to each of the scripts for further instructions on how to use them. From 503276c46315b589aa8d82bcf8ada87553d74fb6 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 21 Jan 2021 21:11:34 +0530 Subject: [PATCH 43/72] Implement Depthwise Separable Convolution Add support for tf.nn.depthwise_conv2d --- Athos/CompileTFGraph.py | 2 +- Athos/CompilerScripts/compile_tf.py | 3 +- Athos/SeeDot/Type.py | 2 +- Athos/TFCompiler/ProcessTFGraph.py | 18 +++++++++-- Athos/TFCompiler/TFNodesAST.py | 20 +++++++++++- Athos/tests/tf/unittests/test_convolution.py | 34 ++++++++++++++++++-- 6 files changed, 70 insertions(+), 9 deletions(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index e8e3fa9a..86325150 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -207,7 +207,7 @@ def generate_code(params, debug=False): os.system("rm {}".format(ezpc_file_name)) output_file = os.path.join(model_abs_dir, output_name) - print("Compiling generated code to {target} target".format(target)) + print("Compiling generated code to {target} target".format(target=target)) if target == "PORTHOS2PC": program_name = model_base_name + "_" + target + "_" + backend + ".out" else: diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py index 2e747f83..c078918f 100644 --- a/Athos/CompilerScripts/compile_tf.py +++ b/Athos/CompilerScripts/compile_tf.py @@ -93,8 +93,7 @@ def infer_input_info(graph): inp_shape = [] else: inp_shape = input_t.shape.as_list() - assert None not in inp_shape, "Placeholder node " + i.name + "has unknown" - +" shape. Please specify name and shape in config" + assert None not in inp_shape, "Placeholder node " + i.name + "has unknown shape. Please specify name and shape in config" input_t_info[i.name] = inp_shape return input_t_info diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 33f880bb..7c37c3f5 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -370,7 +370,7 @@ def visitBopConv(self, node:AST.BOp, eType:Type, fType:Type, args=None): assert(FH == node.options[AST.PaddingKeysDict.FH]) assert(FW == node.options[AST.PaddingKeysDict.FW]) - assert(CI1*group == CI) + assert CI1*group == CI, "FCI={} group={} CI={}".format(CI1, group, CI) zPadHLeft = node.options[AST.PaddingKeysDict.zPadHLeft] zPadHRight = node.options[AST.PaddingKeysDict.zPadHRight] zPadWLeft = node.options[AST.PaddingKeysDict.zPadWLeft] diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 64e84d2d..cd81acba 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -115,7 +115,8 @@ def prefixAllPlaceHolderNodes(graph): # List of Optimisations # 1. Split squared difference into (a-b)*(a-b) -def simplifyGraph(graph): +# 2. Reshape filter of depth separable convolution to convert it to a grouped convolution +def simplifyGraph(graph, sizeInfo): allNodes = graph.getAllNodesRef() nodesMap = graph.getAllNodes() newNodes = [] @@ -134,6 +135,19 @@ def simplifyGraph(graph): nodesMap[mul.getName()] = mul inputsFixup[curNode.getName()] = mul.getName() nodesMap.pop(curNode.getName()) + elif (curNode.getOp() == "DepthwiseConv2dNative"): + filter_shape = sizeInfo[inputs[1]] + in_channels = filter_shape[2] + channel_multiplier = filter_shape[3] + output_channels = in_channels * channel_multiplier + # new filter shape = [FH, FW, 1, CI*CM] + new_filter_shape = filter_shape[0:2] + [1, output_channels] + reshape = Graph.Node("Reshape", [inputs[1]], curNode.getName() + "__reshape") + newNodes.append(reshape) + newNodes.append(curNode) + nodesMap[reshape.getName()] = reshape + inputs[1] = reshape.getName() + sizeInfo[reshape.getName()] = new_filter_shape else: newNodes.append(curNode) graph.setNodesList(newNodes) @@ -155,7 +169,7 @@ def process_tf_graph(filename): sizeInfo = readSizeInfo(sizeInfoFileName) # Tensorflow graph level optimisations - simplifyGraph(graph) + simplifyGraph(graph, sizeInfo) # Place all PlaceHolder nodes together at the beginning prefixAllPlaceHolderNodes(graph) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index fe934564..e2faff7c 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -292,7 +292,7 @@ def Fill(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : di def Reshape(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) + #assert(len(inputsRef) == 2) return (None, { curNode.getName() : AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None)}) def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD = None, FD = None, strideD = None): @@ -379,6 +379,24 @@ def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), options)}) + # A depthwise separable convolution is equivalent to a grouped convolution + # with no. of groups = the no. of input channels (G=CI) + # This however requires a reshape of the filter. + # Regular filter is [ FH, FW, CI, CM (channel_multiplier)] + # Doing depthwise conv results in CO = CI * CM + # Grouped conv expects [FH, FW, CI/G, (CO/G)*G] + # So we reshape to [FH, FW, 1, CI * CM] + def DepthwiseConv2dNative(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): + inputsRef = curNode.getInputsRef() + assert(len(inputsRef)==2) + # Reshape of filter is done in simplifyGraph + img_shape = extraNodeInfoDict[inputsRef[0]][0] + in_channels = img_shape[3] #NHWC + groups = in_channels + _ , nodeToAST = TFNodesAST.Conv2D(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) + nodeToAST[curNode.getName()].options[AST.PaddingKeysDict.group] = groups + return (None, nodeToAST) + def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): inputsRef = curNode.getInputsRef() assert(len(inputsRef)==2) diff --git a/Athos/tests/tf/unittests/test_convolution.py b/Athos/tests/tf/unittests/test_convolution.py index 0301deaa..8c98ce7f 100644 --- a/Athos/tests/tf/unittests/test_convolution.py +++ b/Athos/tests/tf/unittests/test_convolution.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np @@ -65,6 +65,36 @@ def test_conv(test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, return +@pytest.mark.parametrize( + "tfOp, a_shape, kernel_shape, strides, padding", + [ + (tf.nn.depthwise_conv2d, [1, 5, 5, 1], [2, 2, 1, 3], [1, 1, 1, 1], "VALID"), + (tf.nn.depthwise_conv2d, [1, 5, 5, 1], [2, 2, 1, 3], [1, 1, 1, 1], "SAME"), + (tf.nn.depthwise_conv2d, [1, 5, 5, 3], [2, 2, 3, 2], [1, 1, 1, 1], "VALID"), + ], +) +@pytest.mark.parametrize("dtype", [np.single]) +def test_depthwise_conv( + test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, dtype +): + graph = tf.Graph() + a_inp = dtype(np.random.randn(*a_shape)) + kernel_inp = dtype(np.random.randn(*kernel_shape)) + with graph.as_default(): + a = tf.compat.v1.placeholder(tf.as_dtype(dtype), shape=a_inp.shape, name="a") + filters = tf.constant(kernel_inp, name="filter") + output = tfOp(a, filters, strides, padding, name="output") + with tf.compat.v1.Session(graph=graph) as sess: + expected_output = sess.run(output, feed_dict={a: a_inp}) + + config = Config(backend).add_input(a).add_output(output) + config.config["scale"] = 12 + compiler = Compiler(graph, config, test_dir) + mpc_output = compiler.compile_and_run([a_inp]) + assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) + return + + @pytest.mark.parametrize( "tfOp, a_shape, kernel_shape, output_shape, strides, padding", [ From 3aa5d682d2247f76cff46e04f5a07d0d9fd90d2e Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 22 Jan 2021 00:55:46 +0530 Subject: [PATCH 44/72] Ensure topological sorting in graphdef. We expect the input tensorflow graph_def to be topologically sorted but have encountered models where this does not hold true. So we manually sort it such that we process inputs before outputs. --- Athos/CompileSampleNetworks.py | 28 +++++++++--------- Athos/CompileTFGraph.py | 28 +++++++++--------- Athos/CompilerScripts/compile_tf.py | 2 +- Athos/TFCompiler/ProcessTFGraph.py | 46 ++++++++++++++++++++++++++++- 4 files changed, 74 insertions(+), 30 deletions(-) diff --git a/Athos/CompileSampleNetworks.py b/Athos/CompileSampleNetworks.py index 5c474ef7..cec17087 100644 --- a/Athos/CompileSampleNetworks.py +++ b/Athos/CompileSampleNetworks.py @@ -145,10 +145,10 @@ def generate_code(params, debug=False): ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) seedot_args = "" - seedot_args += "--astFile {}/astOutput.pkl --consSF {} ".format( + seedot_args += "--astFile \"{}/astOutput.pkl\" --consSF {} ".format( model_abs_dir, scale ) - seedot_args += "--bitlen {} --outputFileName {} ".format(bitlength, ezpc_abs_path) + seedot_args += "--bitlen {} --outputFileName \"{}\" ".format(bitlength, ezpc_abs_path) seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) @@ -181,15 +181,15 @@ def generate_code(params, debug=False): post = "" temp = os.path.join(model_abs_dir, "temp.ezpc") os.system( - "cat {pre} {common} {post} {ezpc}> {temp}".format( + "cat \"{pre}\" \"{common}\" \"{post}\" \"{ezpc}\"> \"{temp}\"".format( pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp ) ) - os.system("mv {temp} {ezpc}".format(temp=temp, ezpc=ezpc_abs_path)) + os.system("mv \"{temp}\" \"{ezpc}\"".format(temp=temp, ezpc=ezpc_abs_path)) ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") # Copy generated code to the ezpc directory - os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) + os.system("cp \"{ezpc}\" \"{ezpc_dir}\"".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) os.chdir(ezpc_dir) ezpc_args = "" ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( @@ -205,12 +205,12 @@ def generate_code(params, debug=False): ezpc_args += "--sf {} ".format(scale) os.system( - "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args + "eval `opam config env`; ./ezpc.sh \"{}\" ".format(ezpc_file_name) + ezpc_args ) os.system( - "mv {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) + "mv \"{output}\" \"{model_dir}\" ".format(output=output_name, model_dir=model_abs_dir) ) - os.system("rm {}".format(ezpc_file_name)) + os.system("rm \"{}\"".format(ezpc_file_name)) output_file = os.path.join(model_abs_dir, output_name) print( @@ -235,7 +235,7 @@ def generate_code(params, debug=False): if target in ["CPP", "CPPRING"]: os.system( - "g++ {opt_flag} -w {file} -o {output}".format( + "g++ {opt_flag} -w \"{file}\" -o \"{output}\"".format( file=output_file, output=program_path, opt_flag=opt_flag ) ) @@ -245,9 +245,9 @@ def generate_code(params, debug=False): if os.path.exists(porthos_lib): os.system( """g++ {opt_flag} -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ - -mrdseed -fpermissive -fpic -std=c++17 -L {porthos_lib} -I {porthos_headers} {file} \ + -mrdseed -fpermissive -fpic -std=c++17 -L \"{porthos_lib}\" -I \"{porthos_headers}\" \"{file}\" \ -lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system \ - -o {output}""".format( + -o \"{output}\"""".format( porthos_lib=porthos_lib, porthos_headers=porthos_src, file=output_file, @@ -268,9 +268,9 @@ def generate_code(params, debug=False): if os.path.exists(sci_lib): os.system( """g++ {opt_flag} -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ - -faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {file} \ - -L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto \ - -o {output}""".format( + -faligned-new -std=c++17 -fopenmp -I \"{eigen}\" -I \"{sci_src}\" \"{file}\" \ + -L \"{sci_lib}\" -lSCI-LinearHE -L \"{seal}\" -lseal -lssl -lcrypto \ + -o \"{output}\"""".format( eigen=eigen_path, sci_src=sci_src, file=output_file, diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 86325150..98b8dd16 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -139,10 +139,10 @@ def generate_code(params, debug=False): ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) seedot_args = "" - seedot_args += "--astFile {}/astOutput.pkl --consSF {} ".format( + seedot_args += "--astFile \"{}/astOutput.pkl\" --consSF {} ".format( model_abs_dir, scale ) - seedot_args += "--bitlen {} --outputFileName {} ".format(bitlength, ezpc_abs_path) + seedot_args += "--bitlen {} --outputFileName \"{}\" ".format(bitlength, ezpc_abs_path) seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) @@ -175,15 +175,15 @@ def generate_code(params, debug=False): post = "" temp = os.path.join(model_abs_dir, "temp.ezpc") os.system( - "cat {pre} {common} {post} {ezpc}> {temp}".format( + "cat \"{pre}\" \"{common}\" \"{post}\" \"{ezpc}\"> \"{temp}\"".format( pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp ) ) - os.system("mv {temp} {ezpc}".format(temp=temp, ezpc=ezpc_abs_path)) + os.system("mv \"{temp}\" \"{ezpc}\"".format(temp=temp, ezpc=ezpc_abs_path)) ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") # Copy generated code to the ezpc directory - os.system("cp {ezpc} {ezpc_dir}".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) + os.system("cp \"{ezpc}\" \"{ezpc_dir}\"".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) os.chdir(ezpc_dir) ezpc_args = "" ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( @@ -199,12 +199,12 @@ def generate_code(params, debug=False): ezpc_args += "--sf {} ".format(scale) os.system( - "eval `opam config env`; ./ezpc.sh {} ".format(ezpc_file_name) + ezpc_args + "eval `opam config env`; ./ezpc.sh \"{}\" ".format(ezpc_file_name) + ezpc_args ) os.system( - "mv {output} {model_dir} ".format(output=output_name, model_dir=model_abs_dir) + "mv \"{output}\" \"{model_dir}\" ".format(output=output_name, model_dir=model_abs_dir) ) - os.system("rm {}".format(ezpc_file_name)) + os.system("rm \"{}\"".format(ezpc_file_name)) output_file = os.path.join(model_abs_dir, output_name) print("Compiling generated code to {target} target".format(target=target)) @@ -220,7 +220,7 @@ def generate_code(params, debug=False): opt_flag = "-O3" if target in ["CPP", "CPPRING"]: os.system( - "g++ {opt_flag} -w {file} -o {output}".format( + "g++ {opt_flag} -w \"{file}\" -o \"{output}\"".format( file=output_file, output=program_path, opt_flag=opt_flag ) ) @@ -230,9 +230,9 @@ def generate_code(params, debug=False): if os.path.exists(porthos_lib): os.system( """g++ {opt_flag} -fopenmp -pthread -w -march=native -msse4.1 -maes -mpclmul \ - -mrdseed -fpermissive -fpic -std=c++17 -L {porthos_lib} -I {porthos_headers} {file} \ + -mrdseed -fpermissive -fpic -std=c++17 -L \"{porthos_lib}\" -I \"{porthos_headers}\" \"{file}\" \ -lPorthos-Protocols -lssl -lcrypto -lrt -lboost_system \ - -o {output}""".format( + -o \"{output}\"""".format( porthos_lib=porthos_lib, porthos_headers=porthos_src, file=output_file, @@ -253,9 +253,9 @@ def generate_code(params, debug=False): if os.path.exists(sci_lib): os.system( """g++ {opt_flag} -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ - -faligned-new -std=c++17 -fopenmp -I {eigen} -I {sci_src} {file} \ - -L {sci_lib} -lSCI-LinearHE -L {seal} -lseal -lssl -lcrypto \ - -o {output}""".format( + -faligned-new -std=c++17 -fopenmp -I \"{eigen}\" -I \"{sci_src}\" \"{file}\" \ + -L \"{sci_lib}\" -lSCI-LinearHE -L \"{seal}\" -lseal -lssl -lcrypto \ + -o \"{output}\"""".format( eigen=eigen_path, sci_src=sci_src, file=output_file, diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py index c078918f..11e07ec8 100644 --- a/Athos/CompilerScripts/compile_tf.py +++ b/Athos/CompilerScripts/compile_tf.py @@ -93,7 +93,7 @@ def infer_input_info(graph): inp_shape = [] else: inp_shape = input_t.shape.as_list() - assert None not in inp_shape, "Placeholder node " + i.name + "has unknown shape. Please specify name and shape in config" + assert None not in inp_shape, "Placeholder node " + i.name + " has unknown shape. Please specify name and shape in config" input_t_info[i.name] = inp_shape return input_t_info diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index cd81acba..94fa359b 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -51,7 +51,7 @@ def generateIRCode(graph, extraInfoDict): mtdAST = MtdAST() for curNode in graph.getAllNodesRef(): for curInp in curNode.getInputsRef(): - assert(curInp in dictNodeNameToOutVarStr), "input={} expected as input but not yet processed".format(curInp) #Consequence of topological sorting of the TF graph + assert(curInp in dictNodeNameToOutVarStr), "input={} expected as input for node={} but not yet processed".format(curInp, curNode.getName()) #Consequence of topological sorting of the TF graph (assignedVarAST, curAsts) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict) for outputName, curAst in curAsts.items(): mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(), @@ -152,6 +152,48 @@ def simplifyGraph(graph, sizeInfo): newNodes.append(curNode) graph.setNodesList(newNodes) +def arrange_input_before_output(graph): + allNodes = graph.getAllNodesRef() + visited = set() + already_sorted = True + for curNode in allNodes: + visited.add(curNode.getName()) + for inp in curNode.getInputsRef(): + if inp not in visited: + already_sorted = False + break + + if already_sorted: + return + + adjList = { i : [] for i in range(len(allNodes))} + position = { node.getName() : i for i,node in enumerate(allNodes)} + for i, curNode in enumerate(allNodes): + inputs = curNode.getInputsRef() + for inp in inputs: + adjList[position[inp]].append(i) + + no_nodes = len(allNodes) + visited = [False] * no_nodes + final_order = [] + + def topo_sort(v): + visited[v] = True + for i in adjList[v]: + if visited[i] == False: + topo_sort(i) + final_order.insert(0,v) + + for i in range(no_nodes): + if visited[i] == False: + topo_sort(i) + + assert len(final_order) == no_nodes, "Lost some nodes while sorting" + newNodes = [allNodes[i] for i in final_order] + graph.setNodesList(newNodes) + return + + def process_tf_graph(filename): sys.setrecursionlimit(10000) @@ -164,6 +206,8 @@ def process_tf_graph(filename): with open(graphFileName) as file: graph.readFromFilePointer(file) + arrange_input_before_output(graph) + # Read the sizeInfo also sizeInfoFileName = os.path.join(folderName, 'sizeInfo.mtdata') sizeInfo = readSizeInfo(sizeInfoFileName) From c917d195bb65ca70daf8fa2db1e49f67c3a8a4a8 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 22 Jan 2021 10:40:39 +0530 Subject: [PATCH 45/72] Make operatorsymboldict alphabetical --- Athos/SeeDot/AST/AST.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index ca747045..e8a5fd12 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -26,23 +26,23 @@ OperatorsSymbolDict = { "ADD": '+', - "SUB": '-', - "MUL": '*', + "ClearMemPublic": 'clearmempublic' + "ClearMemSecret": 'clearmemsecret', "CONV": '#', "CONVTRANSPOSE": "#T", #ConvTranspose - "RELU": 'relu', - "TANH": 'tanh', - "SIGMOID": 'sigmoid', - "SQRT": 'sqrt', - "RSQRT": 'rsqrt', - "Equal": '==', - "ElemWiseMul":'.*', "ElemWiseDiv": './', + "ElemWiseMul":'.*', + "Equal": '==', "Floor": 'floor', - "Shape": 'shape', "Mean": 'mean', - "ClearMemSecret": 'clearmemsecret', - "ClearMemPublic": 'clearmempublic' + "MUL": '*', + "RELU": 'relu', + "RSQRT": 'rsqrt', + "Shape": 'shape', + "SIGMOID": 'sigmoid', + "SQRT": 'sqrt', + "SUB": '-', + "TANH": 'tanh', } class Party(Enum): From a4915d8508c7170a5a0bdce8bdd305e843325d3b Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 22 Jan 2021 11:33:12 +0530 Subject: [PATCH 46/72] Fixes bug introduced due to topological sorting. We were not maintaing the partial order between variable and placeholder nodes. So input was being read in the wrong order in the output program. Now we maintain this partial order while doing topological sorting. --- Athos/SeeDot/AST/AST.py | 2 +- Athos/TFCompiler/Graph.py | 2 +- Athos/TFCompiler/ProcessTFGraph.py | 28 +++++++++++++++++++++------- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index e8a5fd12..a309a3cf 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -26,7 +26,7 @@ OperatorsSymbolDict = { "ADD": '+', - "ClearMemPublic": 'clearmempublic' + "ClearMemPublic": 'clearmempublic', "ClearMemSecret": 'clearmemsecret', "CONV": '#', "CONVTRANSPOSE": "#T", #ConvTranspose diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index ec8440df..b610aa95 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -656,5 +656,5 @@ def __getitem__(self, opName): return self.__Nodes[opName] def print(self): - for _, curNode in self.__Nodes.items(): + for curNode in self.__NodesLi: curNode.print() diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 94fa359b..545cd53d 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -97,17 +97,19 @@ def readSizeInfo(fileName): return sizeInfo # Since later on in the pipeline, the placeholder nodes which come up as cin statements -# are to be excluded from the timing calculation, output all such PlaceHolder nodes together first. -# This doesn't violate the topological ordering because all such PlaceHolder nodes are leaf nodes -# in the graph. +# are to be excluded from the timing calculation, output all such PlaceHolder nodes together first. +# This doesn't violate the topological ordering because all such PlaceHolder nodes are leaf nodes +# in the graph. +# This however extends live ranges of inputs and increases peak memory usage. +# This also maintains the partial order between placeholder/variable nodes def prefixAllPlaceHolderNodes(graph): allNodes = graph.getAllNodesRef() placeHolderNodes = [] + variableNodes = [] remNodes = [] for curNode in allNodes: - if (curNode.getOp() == "Placeholder" or curNode.getOp() == "VariableV2"): - # Assert this is indeed a leaf node - assert(len(curNode.getInputsRef()) == 0) + if curNode.getOp() in ["Placeholder", "VariableV2"]: + assert(len(curNode.getInputsRef()) == 0) # Assert this is indeed a leaf node placeHolderNodes.append(curNode) else: remNodes.append(curNode) @@ -152,6 +154,10 @@ def simplifyGraph(graph, sizeInfo): newNodes.append(curNode) graph.setNodesList(newNodes) +# We have to process all input nodes before output nodes. +# However we cannot change the partial order of the placeholder and variable nodes. +# The model weights are dumped from tensorflow in the original graphdef order and if +# we don't adhere to that, different inputs will be read by the program. def arrange_input_before_output(graph): allNodes = graph.getAllNodesRef() visited = set() @@ -163,6 +169,7 @@ def arrange_input_before_output(graph): already_sorted = False break + # True almost all the time if already_sorted: return @@ -173,6 +180,13 @@ def arrange_input_before_output(graph): for inp in inputs: adjList[position[inp]].append(i) + # Additionally create edges between all placeholder and variable nodes + nodes_seen = [] + for i, curNode in reversed(list(enumerate(allNodes))): + if curNode.getOp() in ["Placeholder", "VariableV2"]: + adjList[i].extend(nodes_seen) + nodes_seen.append(i) + no_nodes = len(allNodes) visited = [False] * no_nodes final_order = [] @@ -214,7 +228,7 @@ def process_tf_graph(filename): # Tensorflow graph level optimisations simplifyGraph(graph, sizeInfo) - # Place all PlaceHolder nodes together at the beginning + # Place all PlaceHolder and variable nodes together at the beginning prefixAllPlaceHolderNodes(graph) # Re-format the input names of nodes From d61c6a1d1b105597d6934890adb6ae89f64d868c Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 22 Jan 2021 11:57:24 +0530 Subject: [PATCH 47/72] [!!MEGA FORMATTING!!] Reformat all Athos code with black. This messes up git blame. Use git blame xyz --ignore-revs-file Athos/.git-blame-ignore-revs Or permanently do git config blame.ignoreRevsFile Athos/.git-blame-ignore-revs The ignore revs file is created in the next commit as we need the hash of this commit. --- Athos/.git-blame-ignore-revs | 1 + Athos/CompileSampleNetworks.py | 22 +- Athos/CompileTFGraph.py | 20 +- Athos/CompilerScripts/change_onnx_output.py | 148 +- .../comparison_scripts/compare_output.py | 21 +- .../comparison_scripts/convert_scale.py | 31 +- .../comparison_scripts/convert_to_signed.py | 26 +- Athos/CompilerScripts/compile_tf.py | 238 +- Athos/CompilerScripts/compile_tf_graph.py | 260 +- .../CompilerScripts/convert_keras_to_onnx.py | 29 +- Athos/CompilerScripts/convert_keras_to_tf.py | 29 +- Athos/CompilerScripts/create_tf_input.py | 149 +- Athos/CompilerScripts/get_pred_tf_graph.py | 160 +- Athos/CompilerScripts/grappler.py | 318 +- Athos/CompilerScripts/parse_config.py | 249 +- .../preprocess_frozen_tf_graph.py | 161 +- Athos/CompilerScripts/tf_graph_io.py | 59 +- Athos/CompilerScripts/tf_graph_trans.py | 275 +- Athos/HelperScripts/Confirm_preprocessing.py | 38 +- .../HelperScripts/Convert_WnId_To_TrainId.py | 28 +- Athos/HelperScripts/FindAccuracy.py | 132 +- Athos/HelperScripts/FindAccuracy_Porthos.py | 165 +- Athos/HelperScripts/FindAccuracy_TF.py | 122 +- Athos/HelperScripts/Random_Image_Selection.py | 22 +- Athos/HelperScripts/Scale_img_and_model.py | 99 +- Athos/HelperScripts/nn_maxmintest.py | 44 +- Athos/Networks/ChestXRay/ChestXRay_tf_main.py | 189 +- .../DenseNet_main_float_acc.py | 88 +- Athos/Networks/DenseNet/DenseNet_main.py | 188 +- .../DenseNet_preprocess_main.py | 255 +- .../DenseNet_preprocessing.py | 609 ++-- Athos/Networks/DenseNet/densenet.py | 425 ++- Athos/Networks/DenseNet/nets_factory.py | 75 +- .../Lenet/lenetLarge_mnist_inference.py | 284 +- .../Networks/Lenet/lenetLarge_mnist_train.py | 259 +- .../Lenet/lenetSmall_mnist_inference.py | 288 +- .../Networks/Lenet/lenetSmall_mnist_train.py | 264 +- .../LogisticRegressionInfer.py | 125 +- .../LogisticRegressionTrain.py | 30 +- .../Networks/OtherBenchmarks/MiniONN_CIFAR.py | 98 +- .../OtherBenchmarks/resnet32_cifar100.py | 312 +- .../ResNet_main_float_acc.py | 81 +- .../ResNet_preprocess_main.py | 717 ++-- .../imagenet_preprocessing.py | 349 +- Athos/Networks/ResNet/ResNet_main.py | 372 +- Athos/Networks/ResNet/Resnet_Model.py | 1012 +++--- Athos/Networks/SecureNNBenchmarks/NetworkA.py | 90 +- Athos/Networks/SecureNNBenchmarks/NetworkB.py | 86 +- Athos/Networks/SecureNNBenchmarks/NetworkC.py | 86 +- Athos/Networks/SecureNNBenchmarks/NetworkD.py | 72 +- .../SqueezeNetCIFAR10/Squeezenet_model.py | 1250 ++++--- Athos/Networks/SqueezeNetCIFAR10/Util.py | 225 +- .../SqueezeNet_main_float_acc.py | 79 +- .../SqNetImgNet_preprocess_main.py | 105 +- .../SqueezeNetImgNet/squeezenet_main.py | 360 +- Athos/ONNXCompiler/ONNXNodesAST.py | 2173 ++++++----- Athos/ONNXCompiler/common.py | 134 +- Athos/ONNXCompiler/create_input.py | 173 +- Athos/ONNXCompiler/onnx_run.py | 52 +- Athos/ONNXCompiler/onnx_run_tf.py | 119 +- Athos/ONNXCompiler/process_onnx.py | 330 +- Athos/ONNXCompiler/test/test.py | 504 +-- Athos/SeeDot/AST/AST.py | 694 ++-- Athos/SeeDot/AST/ASTVisitor.py | 207 +- Athos/SeeDot/AST/IRBuilderAST.py | 20 +- Athos/SeeDot/AST/MtdAST.py | 141 +- Athos/SeeDot/AST/PrintAST.py | 230 +- Athos/SeeDot/Codegen/CodegenBase.py | 395 +- Athos/SeeDot/Codegen/EzPC.py | 242 +- Athos/SeeDot/Compiler.py | 251 +- Athos/SeeDot/IR/IR.py | 633 ++-- Athos/SeeDot/IR/IRBuilderCSF.py | 3195 ++++++++++------- Athos/SeeDot/IR/IRUtil.py | 403 ++- .../SeeDot/Optimizations/GarbageCollector.py | 446 +-- Athos/SeeDot/Optimizations/ReluMaxpoolOpti.py | 46 +- Athos/SeeDot/SeeDot.py | 184 +- Athos/SeeDot/Type.py | 998 ++--- Athos/SeeDot/Util.py | 342 +- Athos/SeeDot/Writer.py | 31 +- Athos/TFCompiler/DumpTFMtData.py | 361 +- Athos/TFCompiler/Graph.py | 581 +-- Athos/TFCompiler/ProcessTFGraph.py | 446 +-- Athos/TFCompiler/TFNodesAST.py | 2109 +++++++---- Athos/tests/conftest.py | 4 +- Athos/tests/tf/unittests/test_arith_binops.py | 55 +- Athos/tests/tf/unittests/test_batchnorm.py | 15 +- Athos/tests/tf/unittests/test_non_linear.py | 5 +- .../tf/unittests/test_shape_manipulation.py | 8 +- Athos/tests/tf/unittests/test_unaryops.py | 8 +- 89 files changed, 15333 insertions(+), 11421 deletions(-) create mode 100644 Athos/.git-blame-ignore-revs diff --git a/Athos/.git-blame-ignore-revs b/Athos/.git-blame-ignore-revs new file mode 100644 index 00000000..45713d37 --- /dev/null +++ b/Athos/.git-blame-ignore-revs @@ -0,0 +1 @@ +da9f654919ac47e08e166b389653b4abfa69900a diff --git a/Athos/CompileSampleNetworks.py b/Athos/CompileSampleNetworks.py index cec17087..fa45a453 100644 --- a/Athos/CompileSampleNetworks.py +++ b/Athos/CompileSampleNetworks.py @@ -79,7 +79,7 @@ def generate_code(params, debug=False): "ResNet", "DenseNet", "SqueezeNetImgNet", - "SqueezeNetCIFAR10" + "SqueezeNetCIFAR10", ], "Network must be any of ResNet/DenseNet/SqueezeNetImgNet/SqueezeNetCIFAR10" scale = 12 if params["scale"] is None else params["scale"] bitlength = 64 if params["bitlength"] is None else params["bitlength"] @@ -145,10 +145,10 @@ def generate_code(params, debug=False): ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) seedot_args = "" - seedot_args += "--astFile \"{}/astOutput.pkl\" --consSF {} ".format( + seedot_args += '--astFile "{}/astOutput.pkl" --consSF {} '.format( model_abs_dir, scale ) - seedot_args += "--bitlen {} --outputFileName \"{}\" ".format(bitlength, ezpc_abs_path) + seedot_args += '--bitlen {} --outputFileName "{}" '.format(bitlength, ezpc_abs_path) seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) @@ -181,15 +181,15 @@ def generate_code(params, debug=False): post = "" temp = os.path.join(model_abs_dir, "temp.ezpc") os.system( - "cat \"{pre}\" \"{common}\" \"{post}\" \"{ezpc}\"> \"{temp}\"".format( + 'cat "{pre}" "{common}" "{post}" "{ezpc}"> "{temp}"'.format( pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp ) ) - os.system("mv \"{temp}\" \"{ezpc}\"".format(temp=temp, ezpc=ezpc_abs_path)) + os.system('mv "{temp}" "{ezpc}"'.format(temp=temp, ezpc=ezpc_abs_path)) ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") # Copy generated code to the ezpc directory - os.system("cp \"{ezpc}\" \"{ezpc_dir}\"".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) + os.system('cp "{ezpc}" "{ezpc_dir}"'.format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) os.chdir(ezpc_dir) ezpc_args = "" ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( @@ -205,12 +205,14 @@ def generate_code(params, debug=False): ezpc_args += "--sf {} ".format(scale) os.system( - "eval `opam config env`; ./ezpc.sh \"{}\" ".format(ezpc_file_name) + ezpc_args + 'eval `opam config env`; ./ezpc.sh "{}" '.format(ezpc_file_name) + ezpc_args ) os.system( - "mv \"{output}\" \"{model_dir}\" ".format(output=output_name, model_dir=model_abs_dir) + 'mv "{output}" "{model_dir}" '.format( + output=output_name, model_dir=model_abs_dir + ) ) - os.system("rm \"{}\"".format(ezpc_file_name)) + os.system('rm "{}"'.format(ezpc_file_name)) output_file = os.path.join(model_abs_dir, output_name) print( @@ -235,7 +237,7 @@ def generate_code(params, debug=False): if target in ["CPP", "CPPRING"]: os.system( - "g++ {opt_flag} -w \"{file}\" -o \"{output}\"".format( + 'g++ {opt_flag} -w "{file}" -o "{output}"'.format( file=output_file, output=program_path, opt_flag=opt_flag ) ) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 98b8dd16..b5d6605f 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -139,10 +139,10 @@ def generate_code(params, debug=False): ezpc_abs_path = os.path.join(model_abs_dir, ezpc_file_name) seedot_args = "" - seedot_args += "--astFile \"{}/astOutput.pkl\" --consSF {} ".format( + seedot_args += '--astFile "{}/astOutput.pkl" --consSF {} '.format( model_abs_dir, scale ) - seedot_args += "--bitlen {} --outputFileName \"{}\" ".format(bitlength, ezpc_abs_path) + seedot_args += '--bitlen {} --outputFileName "{}" '.format(bitlength, ezpc_abs_path) seedot_args += "--disableAllOpti {} ".format(disable_all_hlil_opts) seedot_args += "--disableRMO {} ".format(disable_relu_maxpool_opts) seedot_args += "--disableLivenessOpti {} ".format(disable_garbage_collection) @@ -175,15 +175,15 @@ def generate_code(params, debug=False): post = "" temp = os.path.join(model_abs_dir, "temp.ezpc") os.system( - "cat \"{pre}\" \"{common}\" \"{post}\" \"{ezpc}\"> \"{temp}\"".format( + 'cat "{pre}" "{common}" "{post}" "{ezpc}"> "{temp}"'.format( pre=pre, common=common, post=post, ezpc=ezpc_abs_path, temp=temp ) ) - os.system("mv \"{temp}\" \"{ezpc}\"".format(temp=temp, ezpc=ezpc_abs_path)) + os.system('mv "{temp}" "{ezpc}"'.format(temp=temp, ezpc=ezpc_abs_path)) ezpc_dir = os.path.join(athos_dir, "../EzPC/EzPC/") # Copy generated code to the ezpc directory - os.system("cp \"{ezpc}\" \"{ezpc_dir}\"".format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) + os.system('cp "{ezpc}" "{ezpc_dir}"'.format(ezpc=ezpc_abs_path, ezpc_dir=ezpc_dir)) os.chdir(ezpc_dir) ezpc_args = "" ezpc_args += "--bitlen {bl} --codegen {target} --disable-tac ".format( @@ -199,12 +199,14 @@ def generate_code(params, debug=False): ezpc_args += "--sf {} ".format(scale) os.system( - "eval `opam config env`; ./ezpc.sh \"{}\" ".format(ezpc_file_name) + ezpc_args + 'eval `opam config env`; ./ezpc.sh "{}" '.format(ezpc_file_name) + ezpc_args ) os.system( - "mv \"{output}\" \"{model_dir}\" ".format(output=output_name, model_dir=model_abs_dir) + 'mv "{output}" "{model_dir}" '.format( + output=output_name, model_dir=model_abs_dir + ) ) - os.system("rm \"{}\"".format(ezpc_file_name)) + os.system('rm "{}"'.format(ezpc_file_name)) output_file = os.path.join(model_abs_dir, output_name) print("Compiling generated code to {target} target".format(target=target)) @@ -220,7 +222,7 @@ def generate_code(params, debug=False): opt_flag = "-O3" if target in ["CPP", "CPPRING"]: os.system( - "g++ {opt_flag} -w \"{file}\" -o \"{output}\"".format( + 'g++ {opt_flag} -w "{file}" -o "{output}"'.format( file=output_file, output=program_path, opt_flag=opt_flag ) ) diff --git a/Athos/CompilerScripts/change_onnx_output.py b/Athos/CompilerScripts/change_onnx_output.py index ba86f1ce..a22daeba 100644 --- a/Athos/CompilerScripts/change_onnx_output.py +++ b/Athos/CompilerScripts/change_onnx_output.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import onnx import onnxruntime @@ -31,63 +31,71 @@ model_name = "shufflenet_may17.onnx" output_model_name = "processed_" + model_name -inputs = ['data'] -nodes_to_remove = ['LabelSelector', 'LabelIndexExtractor', 'ZipMap', - 'activation37'] -new_output_names = ['fc'] +inputs = ["data"] +nodes_to_remove = ["LabelSelector", "LabelIndexExtractor", "ZipMap", "activation37"] +new_output_names = ["fc"] batch_size = 1 + def fix_shape(shape_list, batch_size): - if 'None' not in shape_list: - return shape_list - else: - shape_list[0] = batch_size - assert ('None' not in shape_list) , """Other than batch size there are input + if "None" not in shape_list: + return shape_list + else: + shape_list[0] = batch_size + assert ( + "None" not in shape_list + ), """Other than batch size there are input params with unkown dimension""" - return shape_list + return shape_list + def fix_inp_shape(inp, batch_size): - if inp.type.tensor_type.shape.dim[0].dim_param == 'None': - inp.type.tensor_type.shape.dim[0].dim_value = batch_size - return + if inp.type.tensor_type.shape.dim[0].dim_param == "None": + inp.type.tensor_type.shape.dim[0].dim_value = batch_size + return + def get_np_type_from_onnxruntime(typ_str): - np_types = { - 'tensor(float)' : np.float32, - 'tensor(float64)' : np.float64, - 'tensor(int)' : np.int32, - 'tensor(int64)' : np.int64 - } - return np_types[typ_str] + np_types = { + "tensor(float)": np.float32, + "tensor(float64)": np.float64, + "tensor(int)": np.int32, + "tensor(int64)": np.int64, + } + return np_types[typ_str] + def get_onnx_type(arr): - onnx_types = { - np.float32 : TensorProto.FLOAT, - np.float64 : TensorProto.DOUBLE, - np.int32 : TensorProto.INT32, - np.int64 : TensorProto.INT64 - } - return onnx_types[arr.dtype.type] - + onnx_types = { + np.float32: TensorProto.FLOAT, + np.float64: TensorProto.DOUBLE, + np.int32: TensorProto.INT32, + np.int64: TensorProto.INT64, + } + return onnx_types[arr.dtype.type] + model = onnx.load(model_name) # 1. Inputs to remove # Inputs to dead nodes should not show up as inputs for the model # and also not in the initialization list. -inputs_to_remove = [ inp for i in model.graph.node - if i.name in nodes_to_remove for inp in i.input ] -new_inputs = [ i for i in model.graph.input if i.name not in inputs_to_remove ] +inputs_to_remove = [ + inp for i in model.graph.node if i.name in nodes_to_remove for inp in i.input +] +new_inputs = [i for i in model.graph.input if i.name not in inputs_to_remove] # Fix batch size fix_inp_shape(new_inputs[0], batch_size) # 2. Remove their initializers -new_initializers = [ init for init in model.graph.initializer - if init.name not in nodes_to_remove - and init.name not in inputs_to_remove ] +new_initializers = [ + init + for init in model.graph.initializer + if init.name not in nodes_to_remove and init.name not in inputs_to_remove +] # 3. Remove nodes -new_nodes = [ n for n in model.graph.node if n.name not in nodes_to_remove ] +new_nodes = [n for n in model.graph.node if n.name not in nodes_to_remove] # Get Ouput Tensor Types to create ValueInfo for output info @@ -95,45 +103,49 @@ def get_onnx_type(arr): temp_model = ModelProto() temp_model.CopyFrom(model) for i in new_output_names: - op = ValueInfoProto() - op.name = i - temp_model.graph.output.append(op) -onnx.save(temp_model, '__temp.onnx') -sess = onnxruntime.InferenceSession('__temp.onnx') + op = ValueInfoProto() + op.name = i + temp_model.graph.output.append(op) +onnx.save(temp_model, "__temp.onnx") +sess = onnxruntime.InferenceSession("__temp.onnx") sess_inps = sess.get_inputs() input_dict = {} for i in sess_inps: - shape = fix_shape(i.shape, batch_size) - typ = get_np_type_from_onnxruntime(i.type) - input_dict[i.name] = np.random.rand(*shape).astype(typ) + shape = fix_shape(i.shape, batch_size) + typ = get_np_type_from_onnxruntime(i.type) + input_dict[i.name] = np.random.rand(*shape).astype(typ) output_tensors = sess.run(new_output_names, input_dict) if os.path.exists("__temp.onnx"): - os.remove("__temp.onnx") + os.remove("__temp.onnx") # 4. Create new output list -new_outputs = [] -for i in range(0,len(new_output_names)): - name = new_output_names[i] - typ = get_onnx_type(output_tensors[i]) - shape = output_tensors[i].shape - val_info = helper.make_tensor_value_info(name, typ, shape) - new_outputs.append(val_info) - -new_graph = helper.make_graph(new_nodes, - model.graph.name, - new_inputs, - new_outputs, - initializer=new_initializers, - doc_string=model.graph.doc_string, - value_info=model.graph.value_info) -new_model = helper.make_model(new_graph, - ir_version=model.ir_version, - doc_string=model.doc_string, - model_version=model.model_version, - domain=model.domain, - producer_name='MPCOpRemover') +new_outputs = [] +for i in range(0, len(new_output_names)): + name = new_output_names[i] + typ = get_onnx_type(output_tensors[i]) + shape = output_tensors[i].shape + val_info = helper.make_tensor_value_info(name, typ, shape) + new_outputs.append(val_info) + +new_graph = helper.make_graph( + new_nodes, + model.graph.name, + new_inputs, + new_outputs, + initializer=new_initializers, + doc_string=model.graph.doc_string, + value_info=model.graph.value_info, +) +new_model = helper.make_model( + new_graph, + ir_version=model.ir_version, + doc_string=model.doc_string, + model_version=model.model_version, + domain=model.domain, + producer_name="MPCOpRemover", +) new_model.metadata_props.extend(model.metadata_props) new_model.opset_import.pop() new_model.opset_import.extend(model.opset_import) -onnx.save(new_model, 'processed_'+model_name) +onnx.save(new_model, "processed_" + model_name) diff --git a/Athos/CompilerScripts/comparison_scripts/compare_output.py b/Athos/CompilerScripts/comparison_scripts/compare_output.py index 260093e5..5fe57794 100644 --- a/Athos/CompilerScripts/comparison_scripts/compare_output.py +++ b/Athos/CompilerScripts/comparison_scripts/compare_output.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,26 +20,31 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy as np import sys + def extract_txt_to_numpy_array(file, sf): - f = open(file, 'r') - op = [float(int(line.rstrip()))/(2**sf) for line in f] + f = open(file, "r") + op = [float(int(line.rstrip())) / (2 ** sf) for line in f] f.close() return np.array(op, dtype=np.float32) + def extract_float_txt_to_numpy_array(file): - f = open(file, 'r') + f = open(file, "r") op = [float(line.rstrip()) for line in f] f.close() return np.array(op, dtype=np.float32) + if __name__ == "__main__": - if (len(sys.argv) != 5): - print("Usage: compare_output.py floating_point.txt fixed_point.txt SCALING_FACTOR PRECISION") - assert(len(sys.argv) == 5) + if len(sys.argv) != 5: + print( + "Usage: compare_output.py floating_point.txt fixed_point.txt SCALING_FACTOR PRECISION" + ) + assert len(sys.argv) == 5 sf = int(sys.argv[3]) inp1 = extract_float_txt_to_numpy_array(sys.argv[1]) inp2 = extract_txt_to_numpy_array(sys.argv[2], sf) diff --git a/Athos/CompilerScripts/comparison_scripts/convert_scale.py b/Athos/CompilerScripts/comparison_scripts/convert_scale.py index 228428c2..1bf5f183 100644 --- a/Athos/CompilerScripts/comparison_scripts/convert_scale.py +++ b/Athos/CompilerScripts/comparison_scripts/convert_scale.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,17 +20,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import sys -if __name__ == '__main__': - assert(len(sys.argv) == 4) - file_name = sys.argv[1] - input_scale = int(sys.argv[2]) - output_scale = int(sys.argv[3]) - output_file_name = file_name + '_' + str(output_scale) - output_file = open(output_file_name, "w") - with open(file_name, "r") as a_file: - for line in a_file: - output = int((float(int(line.strip()))/(2**input_scale)) * (2**output_scale)) - output_file.write(str(output) + '\n') - output_file.close() + +if __name__ == "__main__": + assert len(sys.argv) == 4 + file_name = sys.argv[1] + input_scale = int(sys.argv[2]) + output_scale = int(sys.argv[3]) + output_file_name = file_name + "_" + str(output_scale) + output_file = open(output_file_name, "w") + with open(file_name, "r") as a_file: + for line in a_file: + output = int( + (float(int(line.strip())) / (2 ** input_scale)) * (2 ** output_scale) + ) + output_file.write(str(output) + "\n") + output_file.close() diff --git a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py index 6b9808e5..0b88e6e7 100644 --- a/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py +++ b/Athos/CompilerScripts/comparison_scripts/convert_to_signed.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,18 +20,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import sys if __name__ == "__main__": - assert(len(sys.argv) == 4) - inp_fname = sys.argv[1] - out_fname = sys.argv[2] - bitlen = int(sys.argv[3]) - f = open(inp_fname, 'r') - op = [(int(line.rstrip())) for line in f] - f.close() - f = open(out_fname, 'w') - for i in op: - f.write(str( i if (i<2**(bitlen-1)) else i - 2**bitlen) + '\n') - f.close() + assert len(sys.argv) == 4 + inp_fname = sys.argv[1] + out_fname = sys.argv[2] + bitlen = int(sys.argv[3]) + f = open(inp_fname, "r") + op = [(int(line.rstrip())) for line in f] + f.close() + f = open(out_fname, "w") + for i in op: + f.write(str(i if (i < 2 ** (bitlen - 1)) else i - 2 ** bitlen) + "\n") + f.close() diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py index 11e07ec8..044e9fbb 100644 --- a/Athos/CompilerScripts/compile_tf.py +++ b/Athos/CompilerScripts/compile_tf.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import argparse import os.path import json @@ -43,147 +43,151 @@ def get_graph_from(graph_def): - with tf.Graph().as_default() as graph: - tf.import_graph_def(graph_def, name="") - return graph + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name="") + return graph def check_operation_exists(graph, op): - op_list = [i.name for i in graph.get_operations()] - return op in op_list + op_list = [i.name for i in graph.get_operations()] + return op in op_list def tensors_exist(graph, tensor_names): - op_list = [i.name for i in graph.get_operations()] - for i in tensor_names: - assert i in op_list, "input " + i + " does not exist in the graph" - return True + op_list = [i.name for i in graph.get_operations()] + for i in tensor_names: + assert i in op_list, "input " + i + " does not exist in the graph" + return True def set_input_shapes(graph, input_t_info): - tensor_names = input_t_info.keys() - assert tensors_exist(graph, tensor_names) + tensor_names = input_t_info.keys() + assert tensors_exist(graph, tensor_names) - graph_def = graph.as_graph_def() - inputs = [i for i in graph.get_operations() if i.type == "Placeholder"] + graph_def = graph.as_graph_def() + inputs = [i for i in graph.get_operations() if i.type == "Placeholder"] - input_map = {} - with tf.Graph().as_default() as new_graph: - for i in inputs: - if i.name not in input_t_info: - continue - shape = input_t_info[i.name] - input_map[i.name] = tf.compat.v1.placeholder( - i.get_attr("dtype"), shape=shape, name=i.name - ) - tf.import_graph_def(graph_def, input_map=input_map, name="") - return new_graph + input_map = {} + with tf.Graph().as_default() as new_graph: + for i in inputs: + if i.name not in input_t_info: + continue + shape = input_t_info[i.name] + input_map[i.name] = tf.compat.v1.placeholder( + i.get_attr("dtype"), shape=shape, name=i.name + ) + tf.import_graph_def(graph_def, input_map=input_map, name="") + return new_graph def get_tensor(graph, name): - return graph.get_operation_by_name(name).outputs[0] + return graph.get_operation_by_name(name).outputs[0] def infer_input_info(graph): - input_t_info = {} - inputs = [i for i in graph.get_operations() if i.type == "Placeholder"] - for i in inputs: - input_t = i.outputs[0] - if input_t.shape.dims == None: - inp_shape = [] - else: - inp_shape = input_t.shape.as_list() - assert None not in inp_shape, "Placeholder node " + i.name + " has unknown shape. Please specify name and shape in config" - input_t_info[i.name] = inp_shape - return input_t_info + input_t_info = {} + inputs = [i for i in graph.get_operations() if i.type == "Placeholder"] + for i in inputs: + input_t = i.outputs[0] + if input_t.shape.dims == None: + inp_shape = [] + else: + inp_shape = input_t.shape.as_list() + assert None not in inp_shape, ( + "Placeholder node " + + i.name + + " has unknown shape. Please specify name and shape in config" + ) + input_t_info[i.name] = inp_shape + return input_t_info # Generates the computation graph and tensor size metadata and saves them in # the model directory. # Optionaly dumps model weights as fixedpt in specified scaling factor def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weights): - model_name = os.path.basename(model_fname)[:-3] - print("Loading tf graph ", model_fname) - graph = tf_graph_io.load_pb(model_fname) - assert tensors_exist(graph, output_t_names) - - if input_t_info == {}: - input_t_info = infer_input_info(graph) - else: - tensors_exist(graph, list(input_t_info.keys())) - graph = set_input_shapes(graph, input_t_info) - input_t_names = list(input_t_info.keys()) - graph_def = grappler.optimize(graph, input_t_names, output_t_names) - graph_def = grappler.convert_consts_to_var(graph_def) - graph = get_graph_from(graph_def) - - feed_dict = {} - for name, shape in input_t_info.items(): - tensor = get_tensor(graph, name) - zeros = np.zeros(shape) - feed_dict[tensor] = zeros - - cwd = os.getcwd() - with graph.as_default(): - with tf.compat.v1.Session() as sess: - # Run initializers generated by preprocessing - if check_operation_exists(graph, "init_constvars"): - sess.run(graph.get_operation_by_name("init_constvars")) - sess.run(tf.compat.v1.global_variables_initializer()) - model_dir = os.path.realpath(os.path.dirname(model_fname)) - os.chdir(model_dir) - - # At this stage the graph still has constants embedded in it - # in the assign nodes for variables. We cannot execute the graph without - # these constants. We strip them away in a new graph def which is amenable - # to codegen but leave them in the graph. - optimized_graph_def = DumpTFMtData.strip_variable_init_constants( - graph_def, input_t_names, output_t_names - ) - - tf_graph_io.dump_graph_def_pb( - optimized_graph_def, "optimised_" + model_name + ".pb" - ) - DumpTFMtData.save_graphdef(optimized_graph_def) - DumpTFMtData.save_sizeinfo(optimized_graph_def, sess, feed_dict) - print("Model compilation done.") - weights_path = "" - if save_weights: - weights_fname = ( - model_name - + "_input_weights_fixedpt_scale_" - + str(scaling_factor) - + ".inp" - ) - print( - "\nDumping model weights in ", - model_dir + "/" + weights_fname, - ".\nThese are to be used as input for party which owns the model\n", - ) - DumpTFMtData.save_weights( - optimized_graph_def, sess, feed_dict, weights_fname, scaling_factor - ) - weights_path = os.path.join(model_dir, weights_fname) - os.chdir(cwd) - return weights_path + model_name = os.path.basename(model_fname)[:-3] + print("Loading tf graph ", model_fname) + graph = tf_graph_io.load_pb(model_fname) + assert tensors_exist(graph, output_t_names) + + if input_t_info == {}: + input_t_info = infer_input_info(graph) + else: + tensors_exist(graph, list(input_t_info.keys())) + graph = set_input_shapes(graph, input_t_info) + input_t_names = list(input_t_info.keys()) + graph_def = grappler.optimize(graph, input_t_names, output_t_names) + graph_def = grappler.convert_consts_to_var(graph_def) + graph = get_graph_from(graph_def) + + feed_dict = {} + for name, shape in input_t_info.items(): + tensor = get_tensor(graph, name) + zeros = np.zeros(shape) + feed_dict[tensor] = zeros + + cwd = os.getcwd() + with graph.as_default(): + with tf.compat.v1.Session() as sess: + # Run initializers generated by preprocessing + if check_operation_exists(graph, "init_constvars"): + sess.run(graph.get_operation_by_name("init_constvars")) + sess.run(tf.compat.v1.global_variables_initializer()) + model_dir = os.path.realpath(os.path.dirname(model_fname)) + os.chdir(model_dir) + + # At this stage the graph still has constants embedded in it + # in the assign nodes for variables. We cannot execute the graph without + # these constants. We strip them away in a new graph def which is amenable + # to codegen but leave them in the graph. + optimized_graph_def = DumpTFMtData.strip_variable_init_constants( + graph_def, input_t_names, output_t_names + ) + + tf_graph_io.dump_graph_def_pb( + optimized_graph_def, "optimised_" + model_name + ".pb" + ) + DumpTFMtData.save_graphdef(optimized_graph_def) + DumpTFMtData.save_sizeinfo(optimized_graph_def, sess, feed_dict) + print("Model compilation done.") + weights_path = "" + if save_weights: + weights_fname = ( + model_name + + "_input_weights_fixedpt_scale_" + + str(scaling_factor) + + ".inp" + ) + print( + "\nDumping model weights in ", + model_dir + "/" + weights_fname, + ".\nThese are to be used as input for party which owns the model\n", + ) + DumpTFMtData.save_weights( + optimized_graph_def, sess, feed_dict, weights_fname, scaling_factor + ) + weights_path = os.path.join(model_dir, weights_fname) + os.chdir(cwd) + return weights_path def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--config", required=True, type=str, help="Path to the config file" - ) - args = parser.parse_args() - return args + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", required=True, type=str, help="Path to the config file" + ) + args = parser.parse_args() + return args if __name__ == "__main__": - args = parse_args() - params = parse_config.get_params(args.config) - compile( - params["model_name"], - params["input_tensors"], - params["output_tensors"], - params["scale"], - params["save_weights"], - ) + args = parse_args() + params = parse_config.get_params(args.config) + compile( + params["model_name"], + params["input_tensors"], + params["output_tensors"], + params["scale"], + params["save_weights"], + ) diff --git a/Athos/CompilerScripts/compile_tf_graph.py b/Athos/CompilerScripts/compile_tf_graph.py index 7e0d7187..ca5808ca 100644 --- a/Athos/CompilerScripts/compile_tf_graph.py +++ b/Athos/CompilerScripts/compile_tf_graph.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np import argparse @@ -30,110 +30,180 @@ import os.path import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "TFCompiler")) import DumpTFMtData from os import path -def check_operation_exists(graph, tensor_name): - op_list = [i.name for i in graph.get_operations()] - return tensor_name in op_list - -def compile(model_fname, input_t_name, output_t_name, scaling_factor, save_weights, input_shape): - if not model_fname.endswith('.pb'): - sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") - elif not "mpc_processed_" in model_fname: - sys.exit("""Please process model using preprocess_frozen_tf_graph.py. +def check_operation_exists(graph, tensor_name): + op_list = [i.name for i in graph.get_operations()] + return tensor_name in op_list + + +def compile( + model_fname, input_t_name, output_t_name, scaling_factor, save_weights, input_shape +): + if not model_fname.endswith(".pb"): + sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") + elif not "mpc_processed_" in model_fname: + sys.exit( + """Please process model using preprocess_frozen_tf_graph.py. This will optimise it and generate a new .pb with mpc_processed prefix. -Use that with this script.""") - else: - model_name = os.path.basename(model_fname)[:-3] - - print("Loading processed tf graph ", model_fname) - graph = load_pb(model_fname) - - if not check_operation_exists(graph, output_t_name): - sys.exit(output_t_name + " output does not exist in the graph") - output_t = graph.get_operation_by_name(output_t_name).outputs[0] - - if input_t_name != "": - if not check_operation_exists(graph, input_t_name): - sys.exit(input_t_name + " input does not exist in the graph") - - input_t = graph.get_operation_by_name(input_t_name).outputs[0] - - # Generate random tensor as input - # scalar input - if input_t.shape.dims == None: - inp_shape = [] +Use that with this script.""" + ) else: - inp_shape = input_t.shape.as_list() - if None in inp_shape: - if input_shape == []: - sys.exit("Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help.") + model_name = os.path.basename(model_fname)[:-3] + + print("Loading processed tf graph ", model_fname) + graph = load_pb(model_fname) + + if not check_operation_exists(graph, output_t_name): + sys.exit(output_t_name + " output does not exist in the graph") + output_t = graph.get_operation_by_name(output_t_name).outputs[0] + + if input_t_name != "": + if not check_operation_exists(graph, input_t_name): + sys.exit(input_t_name + " input does not exist in the graph") + + input_t = graph.get_operation_by_name(input_t_name).outputs[0] + + # Generate random tensor as input + # scalar input + if input_t.shape.dims == None: + inp_shape = [] else: - inp_shape = input_shape - rand_inp_t = np.zeros(inp_shape) - - feed_dict = {input_t: rand_inp_t} - else: - # We can collect all placeholder nodes as inputs to the model - inputs = [i for i in graph.get_operations() if i.type=="Placeholder"] - feed_dict = {} - for op in inputs: - input_t = op.outputs[0] - if input_t.shape.dims == None: - inp_shape = [] - else: - inp_shape = input_t.shape.as_list() - if None in inp_shape: - sys.exit("Please supply input names and their shapes for the input tensor as it is parametric (? dim) for this model. See --help.") - rand_inp_t = np.zeros(inp_shape) - feed_dict[input_t] = rand_inp_t - - with graph.as_default(): - with tf.Session() as sess: - # Run initializers generated by preprocessing - if check_operation_exists(graph, 'init_constvars'): - sess.run(graph.get_operation_by_name('init_constvars')) - else: - sess.run(tf.global_variables_initializer()) - # Dump sizeInfo, graphDef mtdata and weight dump in model folder. - model_dir = os.path.realpath(os.path.dirname(model_fname)) - os.chdir(model_dir) - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_t, sess, feed_dict) - print("Model compilation done.") - trainVarsName = [node.name for node in optimized_graph_def.node if node.op == "VariableV2" or node.op == "Variable"] - trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) - if save_weights: - DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess) - weights_fname = model_name[len("mpc_processed_"):] + '_input_weights_fixedpt_scale_' + str(scaling_factor) + '.inp' - print("\nDumping model weights in ", weights_fname, ". These are to be used as input for party which owns the model\n") - DumpTFMtData.dumpTrainedWeightsInt(sess, trainVars, weights_fname, scaling_factor, 'w') + inp_shape = input_t.shape.as_list() + if None in inp_shape: + if input_shape == []: + sys.exit( + "Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help." + ) + else: + inp_shape = input_shape + rand_inp_t = np.zeros(inp_shape) + + feed_dict = {input_t: rand_inp_t} + else: + # We can collect all placeholder nodes as inputs to the model + inputs = [i for i in graph.get_operations() if i.type == "Placeholder"] + feed_dict = {} + for op in inputs: + input_t = op.outputs[0] + if input_t.shape.dims == None: + inp_shape = [] + else: + inp_shape = input_t.shape.as_list() + if None in inp_shape: + sys.exit( + "Please supply input names and their shapes for the input tensor as it is parametric (? dim) for this model. See --help." + ) + rand_inp_t = np.zeros(inp_shape) + feed_dict[input_t] = rand_inp_t + + with graph.as_default(): + with tf.Session() as sess: + # Run initializers generated by preprocessing + if check_operation_exists(graph, "init_constvars"): + sess.run(graph.get_operation_by_name("init_constvars")) + else: + sess.run(tf.global_variables_initializer()) + # Dump sizeInfo, graphDef mtdata and weight dump in model folder. + model_dir = os.path.realpath(os.path.dirname(model_fname)) + os.chdir(model_dir) + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_t, sess, feed_dict + ) + print("Model compilation done.") + trainVarsName = [ + node.name + for node in optimized_graph_def.node + if node.op == "VariableV2" or node.op == "Variable" + ] + trainVars = list( + map( + lambda x: tf.get_default_graph() + .get_operation_by_name(x) + .outputs[0], + trainVarsName, + ) + ) + if save_weights: + DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess) + weights_fname = ( + model_name[len("mpc_processed_") :] + + "_input_weights_fixedpt_scale_" + + str(scaling_factor) + + ".inp" + ) + print( + "\nDumping model weights in ", + weights_fname, + ". These are to be used as input for party which owns the model\n", + ) + DumpTFMtData.dumpTrainedWeightsInt( + sess, trainVars, weights_fname, scaling_factor, "w" + ) + def boolean_string(s): - if s not in {'False', 'True'}: - raise ValueError('Not a valid boolean string') - return s == 'True' + if s not in {"False", "True"}: + raise ValueError("Not a valid boolean string") + return s == "True" + def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--modelName", required=True, type=str, help="Name of processed tensorflow model (mpc_processed*.pb)") - parser.add_argument("--inputTensorName", type=str, default='', help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") - parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") - parser.add_argument("--sf", default=12, type=int, help="scaling factor (int)") - parser.add_argument("--saveWeights", type=boolean_string, default=False, help="Dump model weights in fixedpt {True/False}") - parser.add_argument("--inputTensorShape", type=str, default='', help="Comma separated list of shape for input tensor. eg: \"2,245,234,3\"") - args = parser.parse_args() - return args + parser = argparse.ArgumentParser() + parser.add_argument( + "--modelName", + required=True, + type=str, + help="Name of processed tensorflow model (mpc_processed*.pb)", + ) + parser.add_argument( + "--inputTensorName", + type=str, + default="", + help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)", + ) + parser.add_argument( + "--outputTensorName", + required=True, + type=str, + help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)", + ) + parser.add_argument("--sf", default=12, type=int, help="scaling factor (int)") + parser.add_argument( + "--saveWeights", + type=boolean_string, + default=False, + help="Dump model weights in fixedpt {True/False}", + ) + parser.add_argument( + "--inputTensorShape", + type=str, + default="", + help='Comma separated list of shape for input tensor. eg: "2,245,234,3"', + ) + args = parser.parse_args() + return args + def get_shape_list(shape_string): - if shape_string == '': - return [] - return [int(i) for i in shape_string.split(",")] - -if __name__ == '__main__': - args = parse_args() - shape_list = get_shape_list(args.inputTensorShape) - compile(args.modelName, args.inputTensorName, args.outputTensorName, args.sf, args.saveWeights, shape_list) + if shape_string == "": + return [] + return [int(i) for i in shape_string.split(",")] + + +if __name__ == "__main__": + args = parse_args() + shape_list = get_shape_list(args.inputTensorShape) + compile( + args.modelName, + args.inputTensorName, + args.outputTensorName, + args.sf, + args.saveWeights, + shape_list, + ) diff --git a/Athos/CompilerScripts/convert_keras_to_onnx.py b/Athos/CompilerScripts/convert_keras_to_onnx.py index 06651d37..9db7e340 100644 --- a/Athos/CompilerScripts/convert_keras_to_onnx.py +++ b/Athos/CompilerScripts/convert_keras_to_onnx.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,14 +20,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import onnx from onnx import shape_inference import keras2onnx -model_filename = 'chest_xray_covid19_model.h5' -output_filename = 'covid_resnet.onnx' +model_filename = "chest_xray_covid19_model.h5" +output_filename = "covid_resnet.onnx" input_h = 224 input_w = 224 @@ -35,25 +35,28 @@ keras_model = tf.keras.models.load_model(model_filename) onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name) + def set_input_dim(onnx_model, idx, val): - onnx_model.graph.input[0].type.tensor_type.shape.dim[idx].dim_value = val + onnx_model.graph.input[0].type.tensor_type.shape.dim[idx].dim_value = val + def get_input_dim(onnx_model, idx): - return onnx_model.graph.input[0].type.tensor_type.shape.dim[idx].dim_value + return onnx_model.graph.input[0].type.tensor_type.shape.dim[idx].dim_value + -#If input dims are parametric we need to materialize the dims to constants +# If input dims are parametric we need to materialize the dims to constants # N H W C -dims = { "n" : 0, "h" : 1, "w" : 2, "c" : 3} +dims = {"n": 0, "h": 1, "w": 2, "c": 3} n = get_input_dim(onnx_model, dims["n"]) h = get_input_dim(onnx_model, dims["h"]) w = get_input_dim(onnx_model, dims["w"]) c = get_input_dim(onnx_model, dims["c"]) -if 0 in [n,h,w,c]: - set_input_dim(onnx_model, dims["n"], 1) - set_input_dim(onnx_model, dims["h"], input_h) - set_input_dim(onnx_model, dims["w"], input_w) +if 0 in [n, h, w, c]: + set_input_dim(onnx_model, dims["n"], 1) + set_input_dim(onnx_model, dims["h"], input_h) + set_input_dim(onnx_model, dims["w"], input_w) fixed_model = onnx.shape_inference.infer_shapes(onnx_model) onnx.checker.check_model(fixed_model) -onnx.save_model(fixed_model, output_filename) +onnx.save_model(fixed_model, output_filename) diff --git a/Athos/CompilerScripts/convert_keras_to_tf.py b/Athos/CompilerScripts/convert_keras_to_tf.py index 36f4f0fb..8643c3f9 100644 --- a/Athos/CompilerScripts/convert_keras_to_tf.py +++ b/Athos/CompilerScripts/convert_keras_to_tf.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,16 +20,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf -model_filename = 'chest_xray_covid19_model.h5' -output_filename = 'covid_resnet.pb' +model_filename = "chest_xray_covid19_model.h5" +output_filename = "covid_resnet.pb" + def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): graph = session.graph with graph.as_default(): - freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) + freeze_var_names = list( + set(v.op.name for v in tf.global_variables()).difference( + keep_var_names or [] + ) + ) output_names = output_names or [] output_names += [v.op.name for v in tf.global_variables()] input_graph_def = graph.as_graph_def() @@ -37,13 +42,19 @@ def freeze_session(session, keep_var_names=None, output_names=None, clear_device for node in input_graph_def.node: node.device = "" frozen_graph = tf.graph_util.convert_variables_to_constants( - session, input_graph_def, output_names, freeze_var_names) + session, input_graph_def, output_names, freeze_var_names + ) return frozen_graph + tf.keras.backend.set_learning_phase(0) -with tf.keras.utils.CustomObjectScope({'GlorotUniform': tf.keras.initializers.glorot_uniform()}): +with tf.keras.utils.CustomObjectScope( + {"GlorotUniform": tf.keras.initializers.glorot_uniform()} +): model = tf.keras.models.load_model(model_filename) - frozen_graph = freeze_session(tf.keras.backend.get_session(), - output_names=[out.op.name for out in model.outputs]) + frozen_graph = freeze_session( + tf.keras.backend.get_session(), + output_names=[out.op.name for out in model.outputs], + ) tf.train.write_graph(frozen_graph, ".", output_filename, as_text=False) diff --git a/Athos/CompilerScripts/create_tf_input.py b/Athos/CompilerScripts/create_tf_input.py index 064bb3ab..35906045 100644 --- a/Athos/CompilerScripts/create_tf_input.py +++ b/Athos/CompilerScripts/create_tf_input.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,11 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import argparse -import os +import os import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'ONNXCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ONNXCompiler")) from tf_graph_io import * from tf_graph_trans import * @@ -32,63 +33,99 @@ import common import numpy as np + def check_operation_exists(graph, tensor_name): - op_list = [i.name for i in graph.get_operations()] - return tensor_name in op_list - -def gen_random_input(model_fname, input_t_name, scaling_factor, input_shape, dump_numpy): - if not model_fname.endswith('.pb'): - sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") - else: - model_name = os.path.basename(model_fname)[:-3] - print("Loading processed tf graph ", model_fname) - graph = load_pb(model_fname) - - if not check_operation_exists(graph, input_t_name): - sys.exit(input_t_name + " input does not exist in the graph") - - input_t = graph.get_operation_by_name(input_t_name).outputs[0] - - # Generate random tensor as input - inp_shape = input_t.shape.as_list() - if None in inp_shape: - if input_shape == []: - sys.exit("Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help.") + op_list = [i.name for i in graph.get_operations()] + return tensor_name in op_list + + +def gen_random_input( + model_fname, input_t_name, scaling_factor, input_shape, dump_numpy +): + if not model_fname.endswith(".pb"): + sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") else: - inp_shape = input_shape - rand_inp_t = np.random.rand(*inp_shape) - (chunk_inp, cnt) = common.numpy_float_array_to_fixed_point_val_str(rand_inp_t, scaling_factor) - - model_dir = os.path.realpath(os.path.dirname(model_fname)) - os.chdir(model_dir) - f = open(model_name + '_input_fixedpt_scale_' + str(scaling_factor) + '.inp', 'w') - f.write(chunk_inp) - f.close() - if dump_numpy: - rand_inp_t.dump(model_name + '_input_fixedpt_scale_' + str(scaling_factor) + '.npy') - return + model_name = os.path.basename(model_fname)[:-3] + print("Loading processed tf graph ", model_fname) + graph = load_pb(model_fname) + + if not check_operation_exists(graph, input_t_name): + sys.exit(input_t_name + " input does not exist in the graph") + + input_t = graph.get_operation_by_name(input_t_name).outputs[0] + + # Generate random tensor as input + inp_shape = input_t.shape.as_list() + if None in inp_shape: + if input_shape == []: + sys.exit( + "Please supply shape for the input tensor as it is parametric (? dim) for this model. See --help." + ) + else: + inp_shape = input_shape + rand_inp_t = np.random.rand(*inp_shape) + (chunk_inp, cnt) = common.numpy_float_array_to_fixed_point_val_str( + rand_inp_t, scaling_factor + ) + + model_dir = os.path.realpath(os.path.dirname(model_fname)) + os.chdir(model_dir) + f = open(model_name + "_input_fixedpt_scale_" + str(scaling_factor) + ".inp", "w") + f.write(chunk_inp) + f.close() + if dump_numpy: + rand_inp_t.dump( + model_name + "_input_fixedpt_scale_" + str(scaling_factor) + ".npy" + ) + return + def boolean_string(s): - if s not in {'False', 'True'}: - raise ValueError('Not a valid boolean string') - return s == 'True' + if s not in {"False", "True"}: + raise ValueError("Not a valid boolean string") + return s == "True" + def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--modelName", required=True, type=str, help="Name of processed tensorflow model (mpc_processed*.pb)") - parser.add_argument("--inputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") - parser.add_argument("--sf", default=12, type=int, help="scaling factor (int)") - parser.add_argument("--inputTensorShape", type=str, default='', help="Comma separated list of shape for input tensor. eg: \"2,245,234,3\"") - parser.add_argument("--dumpNumpy", type=boolean_string, default=False, help="Dump model weights in fixedpt {True/False}") - args = parser.parse_args() - return args + parser = argparse.ArgumentParser() + parser.add_argument( + "--modelName", + required=True, + type=str, + help="Name of processed tensorflow model (mpc_processed*.pb)", + ) + parser.add_argument( + "--inputTensorName", + required=True, + type=str, + help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)", + ) + parser.add_argument("--sf", default=12, type=int, help="scaling factor (int)") + parser.add_argument( + "--inputTensorShape", + type=str, + default="", + help='Comma separated list of shape for input tensor. eg: "2,245,234,3"', + ) + parser.add_argument( + "--dumpNumpy", + type=boolean_string, + default=False, + help="Dump model weights in fixedpt {True/False}", + ) + args = parser.parse_args() + return args + def get_shape_list(shape_string): - if shape_string == '': - return [] - return [int(i) for i in shape_string.split(",")] - -if __name__ == '__main__': - args = parse_args() - shape_list = get_shape_list(args.inputTensorShape) - gen_random_input(args.modelName, args.inputTensorName, args.sf, shape_list, args.dumpNumpy) + if shape_string == "": + return [] + return [int(i) for i in shape_string.split(",")] + + +if __name__ == "__main__": + args = parse_args() + shape_list = get_shape_list(args.inputTensorShape) + gen_random_input( + args.modelName, args.inputTensorName, args.sf, shape_list, args.dumpNumpy + ) diff --git a/Athos/CompilerScripts/get_pred_tf_graph.py b/Athos/CompilerScripts/get_pred_tf_graph.py index 00c204ed..e889184a 100644 --- a/Athos/CompilerScripts/get_pred_tf_graph.py +++ b/Athos/CompilerScripts/get_pred_tf_graph.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np import argparse @@ -30,75 +30,115 @@ import os.path import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "TFCompiler")) import DumpTFMtData from os import path + def check_operation_exists(graph, tensor_name): - op_list = [i.name for i in graph.get_operations()] - return tensor_name in op_list + op_list = [i.name for i in graph.get_operations()] + return tensor_name in op_list + def numpy_float_array_to_float_val_str(input_array): - chunk = '' - for val in np.nditer(input_array): - chunk += str(val) + '\n' - return chunk + chunk = "" + for val in np.nditer(input_array): + chunk += str(val) + "\n" + return chunk + def compile(model_fname, input_t_name, output_t_name, input_np_arr, output_fname): - if not model_fname.endswith('.pb'): - sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") - elif not "mpc_processed_" in model_fname: - sys.exit("""Please process model using preprocess_frozen_tf_graph.py. + if not model_fname.endswith(".pb"): + sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") + elif not "mpc_processed_" in model_fname: + sys.exit( + """Please process model using preprocess_frozen_tf_graph.py. This will optimise it and generate a new .pb with mpc_processed prefix. -Use that with this script.""") - else: - model_name = os.path.basename(model_fname)[:-3] - - print("Loading processed tf graph ", model_fname) - graph = load_pb(model_fname) - - if not check_operation_exists(graph, input_t_name): - sys.exit(input_t_name + " input does not exist in the graph") - if not check_operation_exists(graph, output_t_name): - sys.exit(output_t_name + " output does not exist in the graph") - if not os.path.isfile(input_np_arr): - sys.exit(input_np_arr + " file does not exist.") - - input_t = graph.get_operation_by_name(input_t_name).outputs[0] - output_t = graph.get_operation_by_name(output_t_name).outputs[0] - - np_input_t = np.load(input_np_arr, allow_pickle=True) - - feed_dict = {input_t: np_input_t} - with graph.as_default(): - with tf.Session() as sess: - # Run initializers generated by preprocessing - if check_operation_exists(graph, 'init_constvars'): - sess.run(graph.get_operation_by_name('init_constvars')) - else: - sess.run(tf.global_variables_initializer()) - model_dir = os.path.realpath(os.path.dirname(model_fname)) - os.chdir(model_dir) - output = sess.run(output_t, feed_dict) - with open(output_fname, 'w') as f: - f.write(numpy_float_array_to_float_val_str(output)) +Use that with this script.""" + ) + else: + model_name = os.path.basename(model_fname)[:-3] + + print("Loading processed tf graph ", model_fname) + graph = load_pb(model_fname) + + if not check_operation_exists(graph, input_t_name): + sys.exit(input_t_name + " input does not exist in the graph") + if not check_operation_exists(graph, output_t_name): + sys.exit(output_t_name + " output does not exist in the graph") + if not os.path.isfile(input_np_arr): + sys.exit(input_np_arr + " file does not exist.") + + input_t = graph.get_operation_by_name(input_t_name).outputs[0] + output_t = graph.get_operation_by_name(output_t_name).outputs[0] + + np_input_t = np.load(input_np_arr, allow_pickle=True) + + feed_dict = {input_t: np_input_t} + with graph.as_default(): + with tf.Session() as sess: + # Run initializers generated by preprocessing + if check_operation_exists(graph, "init_constvars"): + sess.run(graph.get_operation_by_name("init_constvars")) + else: + sess.run(tf.global_variables_initializer()) + model_dir = os.path.realpath(os.path.dirname(model_fname)) + os.chdir(model_dir) + output = sess.run(output_t, feed_dict) + with open(output_fname, "w") as f: + f.write(numpy_float_array_to_float_val_str(output)) + def boolean_string(s): - if s not in {'False', 'True'}: - raise ValueError('Not a valid boolean string') - return s == 'True' + if s not in {"False", "True"}: + raise ValueError("Not a valid boolean string") + return s == "True" + def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--modelName", required=True, type=str, help="Name of processed tensorflow model (mpc_processed*.pb)") - parser.add_argument("--inputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") - parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)") - parser.add_argument("--inputTensorNumpyArr", required=True, type=str, help="Name of the input tensor numpy array file for the model.") - parser.add_argument("--outputFileName", required=True, type=str, help="Name of the output file to store the prediction.") - args = parser.parse_args() - return args - -if __name__ == '__main__': - args = parse_args() - compile(args.modelName, args.inputTensorName, args.outputTensorName, args.inputTensorNumpyArr, args.outputFileName) + parser = argparse.ArgumentParser() + parser.add_argument( + "--modelName", + required=True, + type=str, + help="Name of processed tensorflow model (mpc_processed*.pb)", + ) + parser.add_argument( + "--inputTensorName", + required=True, + type=str, + help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)", + ) + parser.add_argument( + "--outputTensorName", + required=True, + type=str, + help="Name of the input tensor for the model. (Op name, dont add '/:0' suffix)", + ) + parser.add_argument( + "--inputTensorNumpyArr", + required=True, + type=str, + help="Name of the input tensor numpy array file for the model.", + ) + parser.add_argument( + "--outputFileName", + required=True, + type=str, + help="Name of the output file to store the prediction.", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + compile( + args.modelName, + args.inputTensorName, + args.outputTensorName, + args.inputTensorNumpyArr, + args.outputFileName, + ) diff --git a/Athos/CompilerScripts/grappler.py b/Athos/CompilerScripts/grappler.py index 65b2faa5..32815d5f 100644 --- a/Athos/CompilerScripts/grappler.py +++ b/Athos/CompilerScripts/grappler.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -36,177 +36,179 @@ def get_graph_from(graph_def): - with tf.Graph().as_default() as graph: - tf.import_graph_def(graph_def, name="") - return graph + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name="") + return graph def get_default_config(): - c = tf_optimizer.config_pb2.ConfigProto() - optimizer_opts = c.graph_options.rewrite_options - OFF = RewriterConfig.Toggle.Value("OFF") - optimizer_opts.layout_optimizer = OFF - optimizer_opts.implementation_selector = OFF - optimizer_opts.min_graph_nodes = -1 - optimizer_opts.meta_optimizer_iterations = 2 - optimizer_opts.memory_optimization = RewriterConfig.MemOptType.Value("NO_MEM_OPT") - return c + c = tf_optimizer.config_pb2.ConfigProto() + optimizer_opts = c.graph_options.rewrite_options + OFF = RewriterConfig.Toggle.Value("OFF") + optimizer_opts.layout_optimizer = OFF + optimizer_opts.implementation_selector = OFF + optimizer_opts.min_graph_nodes = -1 + optimizer_opts.meta_optimizer_iterations = 2 + optimizer_opts.memory_optimization = RewriterConfig.MemOptType.Value("NO_MEM_OPT") + return c def get_only_prune_config(): - c = get_default_config() - optimizer_opts = c.graph_options.rewrite_options - OFF = RewriterConfig.Toggle.Value("OFF") - optimizer_opts.constant_folding = OFF - optimizer_opts.shape_optimization = OFF - optimizer_opts.remapping = OFF - optimizer_opts.arithmetic_optimization = OFF - optimizer_opts.dependency_optimization = OFF - optimizer_opts.loop_optimization = OFF - optimizer_opts.function_optimization = OFF - optimizer_opts.meta_optimizer_iterations = 1 - return c + c = get_default_config() + optimizer_opts = c.graph_options.rewrite_options + OFF = RewriterConfig.Toggle.Value("OFF") + optimizer_opts.constant_folding = OFF + optimizer_opts.shape_optimization = OFF + optimizer_opts.remapping = OFF + optimizer_opts.arithmetic_optimization = OFF + optimizer_opts.dependency_optimization = OFF + optimizer_opts.loop_optimization = OFF + optimizer_opts.function_optimization = OFF + optimizer_opts.meta_optimizer_iterations = 1 + return c def get_white_list(graph): - transp_perm_ops = set( - i.inputs[1].op.name for i in graph.get_operations() if i.type == "Transpose" - ) - padding_ops = set( - i.inputs[1].op.name for i in graph.get_operations() if i.type == "Pad" - ) - slice_begin_ops = set( - i.inputs[1].op.name for i in graph.get_operations() if i.type == "Slice" - ) - slice_size_ops = set( - i.inputs[2].op.name for i in graph.get_operations() if i.type == "Slice" - ) - mean_axes_ops = set( - i.inputs[1].op.name for i in graph.get_operations() if i.type == "Mean" - ) - sum_axes_ops = set( - i.inputs[1].op.name for i in graph.get_operations() if i.type == "Sum" - ) - split_dim_ops = set( - i.inputs[0].op.name for i in graph.get_operations() if i.type == "Split" - ) - concat_axes_ops = set( - i.inputs[2].op.name - for i in graph.get_operations() - if i.type == "ConcatV2" or i.type == "Concat" - ) - argmax_axes_ops = set( - i.inputs[1].op.name for i in graph.get_operations() if i.type == "ArgMax" - ) - divisor_ops = set( - i.inputs[1].op.name for i in graph.get_operations() if i.type in ["FloorDiv", "RealDiv"] - ) - - white_list = ( - transp_perm_ops - | padding_ops - | slice_begin_ops - | slice_size_ops - | mean_axes_ops - | sum_axes_ops - | split_dim_ops - | concat_axes_ops - | argmax_axes_ops - | divisor_ops - ) - return list(white_list) - - -def optimize(g, inputs, outputs): - sd = SignatureDef() - for name in inputs: - input_t = g.get_operation_by_name(name).outputs[0] - sd.inputs[name].name = name - sd.inputs[name].dtype = input_t.dtype.as_datatype_enum - sd.inputs[name].tensor_shape.CopyFrom(input_t.shape.as_proto()) - for name in outputs: - output_t = g.get_operation_by_name(name).outputs[0] - sd.outputs[name].name = name - sd.outputs[name].dtype = output_t.dtype.as_datatype_enum - sd.outputs[name].tensor_shape.CopyFrom(output_t.shape.as_proto()) - - tf.compat.v1.enable_resource_variables() - cl = cluster.Cluster(disable_detailed_stats=True) - - # We have to run this twice to eliminate constants that are left after - # optimising away split/pad/transpose nodes. They are const parameters like - # axis, perm. They remain after 1 iter of optimization because we specify them - # in the whitelist - for i in range(2): - if i == 0: - graph = g - c = get_default_config() - else: - graph = get_graph_from(optimized_graph_def) - c = get_only_prune_config() - - white_list = get_white_list(graph) - for name in white_list: - graph.add_to_collection( - GraphKeys.TRAIN_OP, graph.get_operation_by_name(name) - ) - - meta_graph = tf.compat.v1.train.export_meta_graph( - graph_def=graph.as_graph_def(), graph=graph + transp_perm_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Transpose" + ) + padding_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Pad" + ) + slice_begin_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Slice" + ) + slice_size_ops = set( + i.inputs[2].op.name for i in graph.get_operations() if i.type == "Slice" + ) + mean_axes_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Mean" + ) + sum_axes_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Sum" + ) + split_dim_ops = set( + i.inputs[0].op.name for i in graph.get_operations() if i.type == "Split" + ) + concat_axes_ops = set( + i.inputs[2].op.name + for i in graph.get_operations() + if i.type == "ConcatV2" or i.type == "Concat" + ) + argmax_axes_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "ArgMax" + ) + divisor_ops = set( + i.inputs[1].op.name + for i in graph.get_operations() + if i.type in ["FloorDiv", "RealDiv"] ) - meta_graph.signature_def["not_used_key"].CopyFrom(sd) - optimized_graph_def = tf_optimizer.OptimizeGraph( - config_proto=c, metagraph=meta_graph, cluster=cl + white_list = ( + transp_perm_ops + | padding_ops + | slice_begin_ops + | slice_size_ops + | mean_axes_ops + | sum_axes_ops + | split_dim_ops + | concat_axes_ops + | argmax_axes_ops + | divisor_ops ) - # Don't create VarHandleOp, ReadVariableOp, VarIsInitializedOp - # Instead create VariableV2 ops in the future - tf.disable_resource_variables() - return optimized_graph_def + return list(white_list) + + +def optimize(g, inputs, outputs): + sd = SignatureDef() + for name in inputs: + input_t = g.get_operation_by_name(name).outputs[0] + sd.inputs[name].name = name + sd.inputs[name].dtype = input_t.dtype.as_datatype_enum + sd.inputs[name].tensor_shape.CopyFrom(input_t.shape.as_proto()) + for name in outputs: + output_t = g.get_operation_by_name(name).outputs[0] + sd.outputs[name].name = name + sd.outputs[name].dtype = output_t.dtype.as_datatype_enum + sd.outputs[name].tensor_shape.CopyFrom(output_t.shape.as_proto()) + + tf.compat.v1.enable_resource_variables() + cl = cluster.Cluster(disable_detailed_stats=True) + + # We have to run this twice to eliminate constants that are left after + # optimising away split/pad/transpose nodes. They are const parameters like + # axis, perm. They remain after 1 iter of optimization because we specify them + # in the whitelist + for i in range(2): + if i == 0: + graph = g + c = get_default_config() + else: + graph = get_graph_from(optimized_graph_def) + c = get_only_prune_config() + + white_list = get_white_list(graph) + for name in white_list: + graph.add_to_collection( + GraphKeys.TRAIN_OP, graph.get_operation_by_name(name) + ) + + meta_graph = tf.compat.v1.train.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph + ) + meta_graph.signature_def["not_used_key"].CopyFrom(sd) + + optimized_graph_def = tf_optimizer.OptimizeGraph( + config_proto=c, metagraph=meta_graph, cluster=cl + ) + # Don't create VarHandleOp, ReadVariableOp, VarIsInitializedOp + # Instead create VariableV2 ops in the future + tf.disable_resource_variables() + return optimized_graph_def def delete_nodes(gd, ops): - nodes_to_delete = set(op.name for op in ops) - new_gd = tf.compat.v1.GraphDef() - nodes_to_keep = [] - for n in gd.node: - if not n.name in nodes_to_delete: - nodes_to_keep.append(n) - new_gd.node.extend(nodes_to_keep) - return new_gd + nodes_to_delete = set(op.name for op in ops) + new_gd = tf.compat.v1.GraphDef() + nodes_to_keep = [] + for n in gd.node: + if not n.name in nodes_to_delete: + nodes_to_keep.append(n) + new_gd.node.extend(nodes_to_keep) + return new_gd def convert_consts_to_var(graph_def): - graph = get_graph_from(graph_def) - all_const_ops = set(i.name for i in graph.get_operations() if i.type == "Const") - const_names_list = list(all_const_ops - set(get_white_list(graph))) - const_var_names_pairs = [] - ops_to_delete = [] - with graph.as_default(): - preexisting_vars = [ - tf.get_variable(i.name, i.outputs[0].shape) - for i in graph.get_operations() - if i.type == "VariableV2" or i.type == "Variable" - ] - - var_list = [] - for name in const_names_list: - tensor = graph.get_operation_by_name(name).outputs[0] - with tf.compat.v1.Session() as sess: - t_value = sess.run(tensor) - t_name = "{}_mpc_const_var".format(name) - var = tf.compat.v1.Variable(t_value, name=t_name) - var_read_op_name = var.to_proto().snapshot_name[:-2] - const_var_names_pairs.append((name, var_read_op_name)) - var_list.append(var) - - for const_name, var_read_name in const_var_names_pairs: - const_op = graph.get_operation_by_name(const_name) - var_op = graph.get_operation_by_name(var_read_name) - ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_op)) - ops_to_delete.append(const_op) - - tf.compat.v1.variables_initializer( - var_list + preexisting_vars, "init_constvars" - ) - return delete_nodes(graph.as_graph_def(), ops_to_delete) + graph = get_graph_from(graph_def) + all_const_ops = set(i.name for i in graph.get_operations() if i.type == "Const") + const_names_list = list(all_const_ops - set(get_white_list(graph))) + const_var_names_pairs = [] + ops_to_delete = [] + with graph.as_default(): + preexisting_vars = [ + tf.get_variable(i.name, i.outputs[0].shape) + for i in graph.get_operations() + if i.type == "VariableV2" or i.type == "Variable" + ] + + var_list = [] + for name in const_names_list: + tensor = graph.get_operation_by_name(name).outputs[0] + with tf.compat.v1.Session() as sess: + t_value = sess.run(tensor) + t_name = "{}_mpc_const_var".format(name) + var = tf.compat.v1.Variable(t_value, name=t_name) + var_read_op_name = var.to_proto().snapshot_name[:-2] + const_var_names_pairs.append((name, var_read_op_name)) + var_list.append(var) + + for const_name, var_read_name in const_var_names_pairs: + const_op = graph.get_operation_by_name(const_name) + var_op = graph.get_operation_by_name(var_read_name) + ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_op)) + ops_to_delete.append(const_op) + + tf.compat.v1.variables_initializer( + var_list + preexisting_vars, "init_constvars" + ) + return delete_nodes(graph.as_graph_def(), ops_to_delete) diff --git a/Athos/CompilerScripts/parse_config.py b/Athos/CompilerScripts/parse_config.py index 22624ce4..68a0f5bc 100644 --- a/Athos/CompilerScripts/parse_config.py +++ b/Athos/CompilerScripts/parse_config.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import argparse import os.path import json @@ -55,169 +55,170 @@ def get_config(config_path): - if not os.path.isfile(config_path): - sys.exit("Config file specified does not exist") - with open(config_path) as f: - try: - config = json.load(f) - except JSONDecodeError as e: - sys.exit( - "Error while parsing the config json:\n" - + e.msg - + " at line no. " - + str(e.lineno) - ) - return config + if not os.path.isfile(config_path): + sys.exit("Config file specified does not exist") + with open(config_path) as f: + try: + config = json.load(f) + except JSONDecodeError as e: + sys.exit( + "Error while parsing the config json:\n" + + e.msg + + " at line no. " + + str(e.lineno) + ) + return config def get_str_param(config, p_name): - p = config.get(p_name) - if p is None: - sys.exit(p_name + " not specified in config.") - assert type(p) == str, p_name + " is not a string" - return p + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == str, p_name + " is not a string" + return p def get_opt_str_param(config, p_name): - p = config.get(p_name) - if p is None: + p = config.get(p_name) + if p is None: + return p + assert type(p) == str, p_name + " is not a string" return p - assert type(p) == str, p_name + " is not a string" - return p def get_bool_param(config, p_name): - p = config.get(p_name) - if p is None: - sys.exit(p_name + " not specified in config.") - assert type(p) == bool, p_name + " is not a boolean" - return p + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == bool, p_name + " is not a boolean" + return p def get_opt_bool_param(config, p_name): - p = config.get(p_name) - if p is None: + p = config.get(p_name) + if p is None: + return p + assert type(p) == bool, p_name + " is not a boolean" return p - assert type(p) == bool, p_name + " is not a boolean" - return p def get_int_param(config, p_name): - p = config.get(p_name) - if p is None: - sys.exit(p_name + " not specified in config.") - assert type(p) == int, p_name + " is not an integer" - return p + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == int, p_name + " is not an integer" + return p def get_opt_int_param(config, p_name): - p = config.get(p_name) - if p is None: + p = config.get(p_name) + if p is None: + return p + assert type(p) == int, p_name + " is not an integer" return p - assert type(p) == int, p_name + " is not an integer" - return p def get_str_list_param(config, p_name): - p = config.get(p_name) - if p is None: - sys.exit(p_name + " not specified in config.") - assert type(p) == list, p_name + "is not a list of strings" - for i in p: - assert type(i) == str, p_name + "is not a list of strings" - return p + p = config.get(p_name) + if p is None: + sys.exit(p_name + " not specified in config.") + assert type(p) == list, p_name + "is not a list of strings" + for i in p: + assert type(i) == str, p_name + "is not a list of strings" + return p + def get_opt_str_list_param(config, p_name): - p = config.get(p_name) - if p is None: + p = config.get(p_name) + if p is None: + return p + assert type(p) == list, p_name + "is not a list of strings" + for i in p: + assert type(i) == str, p_name + "is not a list of strings" return p - assert type(p) == list, p_name + "is not a list of strings" - for i in p: - assert type(i) == str, p_name + "is not a list of strings" - return p def get_opt_param(config, p_name): - p = config.get(p_name) - return p + p = config.get(p_name) + return p def get_shape_list(shape_string): - shape = [] - if shape_string == "": + shape = [] + if shape_string == "": + return shape + for i in shape_string.split(","): + assert i.isnumeric(), "Given input shape has non-integer value : {}".format(i) + shape.append(int(i)) return shape - for i in shape_string.split(","): - assert i.isnumeric(), "Given input shape has non-integer value : {}".format(i) - shape.append(int(i)) - return shape def parse_input_tensors(config): - input_t_info = {} - p = config.get("input_tensors") - if p is None: + input_t_info = {} + p = config.get("input_tensors") + if p is None: + return input_t_info + assert type(p) == dict, "Input tensors should be a dict of name=>shape" + for name, shape_str in p.items(): + input_t_info[name] = get_shape_list(shape_str) return input_t_info - assert type(p) == dict, "Input tensors should be a dict of name=>shape" - for name, shape_str in p.items(): - input_t_info[name] = get_shape_list(shape_str) - return input_t_info def parse_config(config, sample_network=False): - if not sample_network: - model_fname = get_str_param(config, "model_name") - if not model_fname.endswith(".pb"): - sys.exit( - model_fname - + " is not a tensorflow protobuf file. Please supply " - + "a valid tensorflow protobuf model (.pb extension)" - ) - if not os.path.isfile(model_fname): - sys.exit(model_fname + " file does not exist") - else: - network_name = get_str_param(config, "network_name") - run_in_tmux = get_opt_bool_param(config, "run_in_tmux") - - target = get_str_param(config, "target").upper() - - output_tensors = get_opt_str_list_param(config, "output_tensors") - input_t_info = parse_input_tensors(config) - - save_weights = get_opt_bool_param(config, "save_weights") - scale = get_opt_int_param(config, "scale") - bitlen = get_opt_int_param(config, "bitlength") - modulo = get_opt_int_param(config, "modulo") - backend = get_opt_str_param(config, "backend") - disable_hlil_opts = get_opt_bool_param(config, "disable_all_hlil_opts") - disable_rmo = get_opt_bool_param(config, "disable_relu_maxpool_opts") - disable_garbage_collection = get_opt_bool_param( - config, "disable_garbage_collection" - ) - disable_trunc = get_opt_bool_param(config, "disable_trunc_opts") - - params = { - "input_tensors": input_t_info, - "output_tensors": output_tensors, - "scale": scale, - "bitlength": bitlen, - "target": target, - "save_weights": save_weights, - "modulo": modulo, - "backend": backend, - "disable_all_hlil_opts": disable_hlil_opts, - "disable_relu_maxpool_opts": disable_rmo, - "disable_garbage_collection": disable_garbage_collection, - "disable_trunc_opts": disable_trunc, - } - if sample_network: - params["network_name"] = network_name - params["run_in_tmux"] = run_in_tmux - else: - params["model_name"] = model_fname - return params + if not sample_network: + model_fname = get_str_param(config, "model_name") + if not model_fname.endswith(".pb"): + sys.exit( + model_fname + + " is not a tensorflow protobuf file. Please supply " + + "a valid tensorflow protobuf model (.pb extension)" + ) + if not os.path.isfile(model_fname): + sys.exit(model_fname + " file does not exist") + else: + network_name = get_str_param(config, "network_name") + run_in_tmux = get_opt_bool_param(config, "run_in_tmux") + + target = get_str_param(config, "target").upper() + + output_tensors = get_opt_str_list_param(config, "output_tensors") + input_t_info = parse_input_tensors(config) + + save_weights = get_opt_bool_param(config, "save_weights") + scale = get_opt_int_param(config, "scale") + bitlen = get_opt_int_param(config, "bitlength") + modulo = get_opt_int_param(config, "modulo") + backend = get_opt_str_param(config, "backend") + disable_hlil_opts = get_opt_bool_param(config, "disable_all_hlil_opts") + disable_rmo = get_opt_bool_param(config, "disable_relu_maxpool_opts") + disable_garbage_collection = get_opt_bool_param( + config, "disable_garbage_collection" + ) + disable_trunc = get_opt_bool_param(config, "disable_trunc_opts") + + params = { + "input_tensors": input_t_info, + "output_tensors": output_tensors, + "scale": scale, + "bitlength": bitlen, + "target": target, + "save_weights": save_weights, + "modulo": modulo, + "backend": backend, + "disable_all_hlil_opts": disable_hlil_opts, + "disable_relu_maxpool_opts": disable_rmo, + "disable_garbage_collection": disable_garbage_collection, + "disable_trunc_opts": disable_trunc, + } + if sample_network: + params["network_name"] = network_name + params["run_in_tmux"] = run_in_tmux + else: + params["model_name"] = model_fname + return params def get_params(config_fname, sample_network=False): - config = get_config(config_fname) - params = parse_config(config, sample_network) - return params + config = get_config(config_fname) + params = parse_config(config, sample_network) + return params diff --git a/Athos/CompilerScripts/preprocess_frozen_tf_graph.py b/Athos/CompilerScripts/preprocess_frozen_tf_graph.py index c2d7cb73..462723cf 100644 --- a/Athos/CompilerScripts/preprocess_frozen_tf_graph.py +++ b/Athos/CompilerScripts/preprocess_frozen_tf_graph.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,12 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" from tf_graph_io import * from tf_graph_trans import * import sys import time -import os +import os import argparse import os.path @@ -33,71 +33,102 @@ # Transpose nodes require perm as compile time constants for parametric codegen # So we don't eliminate the constants we need dring compile time def get_const_names(graph): - transp_perm_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Transpose') - padding_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Pad') - slice_begin_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Slice') - slice_size_ops = set(i.inputs[2].op.name for i in graph.get_operations() if i.type == 'Slice') - mean_axes_ops = set(i.inputs[1].op.name for i in graph.get_operations() if i.type == 'Mean') - split_dim_ops = set(i.inputs[0].op.name for i in graph.get_operations() if i.type == 'Split') - white_list = transp_perm_ops | padding_ops | slice_begin_ops | slice_size_ops | mean_axes_ops | split_dim_ops - all_const_ops = set(i.name for i in graph.get_operations() if i.type == 'Const') - return list(all_const_ops - white_list) + transp_perm_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Transpose" + ) + padding_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Pad" + ) + slice_begin_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Slice" + ) + slice_size_ops = set( + i.inputs[2].op.name for i in graph.get_operations() if i.type == "Slice" + ) + mean_axes_ops = set( + i.inputs[1].op.name for i in graph.get_operations() if i.type == "Mean" + ) + split_dim_ops = set( + i.inputs[0].op.name for i in graph.get_operations() if i.type == "Split" + ) + white_list = ( + transp_perm_ops + | padding_ops + | slice_begin_ops + | slice_size_ops + | mean_axes_ops + | split_dim_ops + ) + all_const_ops = set(i.name for i in graph.get_operations() if i.type == "Const") + return list(all_const_ops - white_list) + def check_operation_exists(graph, tensor_name): - op_list = [i.name for i in graph.get_operations()] - return tensor_name in op_list + op_list = [i.name for i in graph.get_operations()] + return tensor_name in op_list + def optimize(input_fname, output_t_name): - if not input_fname.endswith('.pb'): - sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") - - actual_fname = os.path.basename(input_fname) - dirname = os.path.dirname(input_fname) - output_fname = os.path.join(dirname, "mpc_processed_" + actual_fname) - print("Loading ", input_fname, "for processing.") - graph = load_pb(input_fname) - - if not check_operation_exists(graph, output_t_name): - sys.exit(output_t_name + " output does not exist in the graph") - input_names = [i.name for i in graph.get_operations() if i.type=="Placeholder"] - - #graph = remove_dead_nodes(graph, input_names, [output_t_name]) - - print("\n\nThis process will take some time to run as we execute portions of the graph.\n\n") - time.sleep(1) - # Fold away all static computations - - print("Running fold splits") - graph = fold_splits(graph) - print(graph.get_operations(),end="\n\n") - print("Running constant folding") - graph = fold_constants(graph) - - # Convert constants to variables so as to separate the data and the generated code - # Otherwise huge arrays will show up as constants in the generated code, thereby - # increasing binary size. - print("Convert frozen constants to variables") - graph = convert_consts_to_var(graph, get_const_names(graph)) - - - input_names = [i.name for i in graph.get_operations() if i.type=="Placeholder"] - #graph = remove_dead_nodes(graph, input_names, [output_t_name]) - - # At this stage the graph still has constants embedded in it - # in the assign nodes for variables. We cannot execute the graph without - # these constants. However after inferring the size, we can call remove_dead_nodes - # to optimize away the constants and assign nodes and make the graph amenable - # for codegen - dump_pb(graph, output_fname) - print("The processed graph is dumped in ", output_fname) + if not input_fname.endswith(".pb"): + sys.exit("Please supply a valid tensorflow protobuf model (.pb extension)") + + actual_fname = os.path.basename(input_fname) + dirname = os.path.dirname(input_fname) + output_fname = os.path.join(dirname, "mpc_processed_" + actual_fname) + print("Loading ", input_fname, "for processing.") + graph = load_pb(input_fname) + + if not check_operation_exists(graph, output_t_name): + sys.exit(output_t_name + " output does not exist in the graph") + input_names = [i.name for i in graph.get_operations() if i.type == "Placeholder"] + + # graph = remove_dead_nodes(graph, input_names, [output_t_name]) + + print( + "\n\nThis process will take some time to run as we execute portions of the graph.\n\n" + ) + time.sleep(1) + # Fold away all static computations + + print("Running fold splits") + graph = fold_splits(graph) + print(graph.get_operations(), end="\n\n") + print("Running constant folding") + graph = fold_constants(graph) + + # Convert constants to variables so as to separate the data and the generated code + # Otherwise huge arrays will show up as constants in the generated code, thereby + # increasing binary size. + print("Convert frozen constants to variables") + graph = convert_consts_to_var(graph, get_const_names(graph)) + + input_names = [i.name for i in graph.get_operations() if i.type == "Placeholder"] + # graph = remove_dead_nodes(graph, input_names, [output_t_name]) + + # At this stage the graph still has constants embedded in it + # in the assign nodes for variables. We cannot execute the graph without + # these constants. However after inferring the size, we can call remove_dead_nodes + # to optimize away the constants and assign nodes and make the graph amenable + # for codegen + dump_pb(graph, output_fname) + print("The processed graph is dumped in ", output_fname) + def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--modelName", required=True, type=str, help="Name of tensorflow model (*.pb)") - parser.add_argument("--outputTensorName", required=True, type=str, help="Name of the output tensor for the model. (Op name, dont add '/:0' suffix)") - args = parser.parse_args() - return args - -if __name__ == '__main__': - args = parse_args() - optimize(args.modelName, args.outputTensorName) + parser = argparse.ArgumentParser() + parser.add_argument( + "--modelName", required=True, type=str, help="Name of tensorflow model (*.pb)" + ) + parser.add_argument( + "--outputTensorName", + required=True, + type=str, + help="Name of the output tensor for the model. (Op name, dont add '/:0' suffix)", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + optimize(args.modelName, args.outputTensorName) diff --git a/Athos/CompilerScripts/tf_graph_io.py b/Athos/CompilerScripts/tf_graph_io.py index c99d97d6..e23cc2d1 100644 --- a/Athos/CompilerScripts/tf_graph_io.py +++ b/Athos/CompilerScripts/tf_graph_io.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,42 +20,49 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf from tensorflow.python.platform import gfile + def display_graph(graph, tensorboard_log_dir): - writer = tf.summary.FileWriter(tensorboard_log_dir, graph) - writer.close() + writer = tf.summary.FileWriter(tensorboard_log_dir, graph) + writer.close() + def load_pb(path_to_pb): - with tf.io.gfile.GFile(path_to_pb, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - with tf.Graph().as_default() as graph: - tf.import_graph_def(graph_def, name="") - return graph + with tf.io.gfile.GFile(path_to_pb, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name="") + return graph + def dump_pb(graph, filename): - with tf.io.gfile.GFile(filename, 'wb') as f: - graph_def = graph.as_graph_def() - f.write(graph_def.SerializeToString()) + with tf.io.gfile.GFile(filename, "wb") as f: + graph_def = graph.as_graph_def() + f.write(graph_def.SerializeToString()) + def dump_graph_def_pb(graph_def, filename): - with tf.io.gfile.GFile(filename, 'wb') as f: - f.write(graph_def.SerializeToString()) + with tf.io.gfile.GFile(filename, "wb") as f: + f.write(graph_def.SerializeToString()) + def dump_pb_without_vars(graph, output_names, filename): - with tf.io.gfile.GFile(filename, 'wb') as f: - with tf.Session(graph=graph) as sess: - sess.run(tf.global_variables_initializer()) - graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, - graph.as_graph_def(), output_names) - f.write(graph_def.SerializeToString()) + with tf.io.gfile.GFile(filename, "wb") as f: + with tf.Session(graph=graph) as sess: + sess.run(tf.global_variables_initializer()) + graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), output_names + ) + f.write(graph_def.SerializeToString()) + def save_model(graph, model_name): - with graph.as_default(): - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - save_path = tf.train.Saver().save(sess, model_name) - print("Model saved in path: %s" % save_path) + with graph.as_default(): + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + save_path = tf.train.Saver().save(sess, model_name) + print("Model saved in path: %s" % save_path) diff --git a/Athos/CompilerScripts/tf_graph_trans.py b/Athos/CompilerScripts/tf_graph_trans.py index 6cd90c2e..99bd2704 100644 --- a/Athos/CompilerScripts/tf_graph_trans.py +++ b/Athos/CompilerScripts/tf_graph_trans.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,153 +20,176 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import tensorflow.contrib.graph_editor as ge from tensorflow.tools.graph_transforms import TransformGraph + def delete_nodes(graph, ops): - gd = graph.as_graph_def() - nodes_to_delete = set(op.name for op in ops) - new_gd = tf.compat.v1.GraphDef() - nodes_to_keep = [] - for n in gd.node: - if not n.name in nodes_to_delete: - nodes_to_keep.append(n) - new_gd.node.extend(nodes_to_keep) - new_graph = tf.Graph() - with new_graph.as_default(): - tf.import_graph_def(new_gd, name="") - return new_graph + gd = graph.as_graph_def() + nodes_to_delete = set(op.name for op in ops) + new_gd = tf.compat.v1.GraphDef() + nodes_to_keep = [] + for n in gd.node: + if not n.name in nodes_to_delete: + nodes_to_keep.append(n) + new_gd.node.extend(nodes_to_keep) + new_graph = tf.Graph() + with new_graph.as_default(): + tf.import_graph_def(new_gd, name="") + return new_graph + def remove_dead_nodes(graph, in_list, out_list): - transforms = ['remove_nodes(op=Identity)', 'strip_unused_nodes'] - optimized_graph_def = TransformGraph(graph.as_graph_def(), in_list, out_list, transforms) - with tf.Graph().as_default() as opt_graph: - tf.import_graph_def(optimized_graph_def, name="") - return opt_graph + transforms = ["remove_nodes(op=Identity)", "strip_unused_nodes"] + optimized_graph_def = TransformGraph( + graph.as_graph_def(), in_list, out_list, transforms + ) + with tf.Graph().as_default() as opt_graph: + tf.import_graph_def(optimized_graph_def, name="") + return opt_graph + def convert_consts_to_var(graph, const_names_list): - const_var_names_pairs = [] - ops_to_delete = [] - with graph.as_default(): - preexisting_vars = [tf.get_variable(i.name, i.outputs[0].shape) for i in graph.get_operations() if i.type=="VariableV2" or i.type=="Variable"] - - var_list = [] - for name in const_names_list: - tensor = graph.get_operation_by_name(name).outputs[0] - with tf.Session() as sess: - t_value = sess.run(tensor) - t_name = '{}_mpc_const_var'.format(name) - var = tf.Variable(t_value, name=t_name) - const_var_names_pairs.append((name, t_name)) - var_list.append(var) - - for const_name, var_name in const_var_names_pairs: - const_op = graph.get_operation_by_name(const_name) - var_op = graph.get_operation_by_name('{}/read'.format(var_name)) - ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_op)) - ops_to_delete.append(const_op) - - tf.compat.v1.variables_initializer(var_list + preexisting_vars, 'init_constvars') - return delete_nodes(graph, ops_to_delete) + const_var_names_pairs = [] + ops_to_delete = [] + with graph.as_default(): + preexisting_vars = [ + tf.get_variable(i.name, i.outputs[0].shape) + for i in graph.get_operations() + if i.type == "VariableV2" or i.type == "Variable" + ] + + var_list = [] + for name in const_names_list: + tensor = graph.get_operation_by_name(name).outputs[0] + with tf.Session() as sess: + t_value = sess.run(tensor) + t_name = "{}_mpc_const_var".format(name) + var = tf.Variable(t_value, name=t_name) + const_var_names_pairs.append((name, t_name)) + var_list.append(var) + + for const_name, var_name in const_var_names_pairs: + const_op = graph.get_operation_by_name(const_name) + var_op = graph.get_operation_by_name("{}/read".format(var_name)) + ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_op)) + ops_to_delete.append(const_op) + + tf.compat.v1.variables_initializer( + var_list + preexisting_vars, "init_constvars" + ) + return delete_nodes(graph, ops_to_delete) + def get_inputs(op): - return set(input.op for input in op.inputs) + return set(input.op for input in op.inputs) + def replace_node_with_const(node): - print("Trying to execute node {}".format(node.name)) - graph = node.graph - with graph.as_default(): - const_lists = [] - with tf.Session() as sess: - for out_t in node.outputs: - const_val = sess.run(out_t) - const_op = tf.constant(const_val).op - const_lists.append(const_op) - ge.swap_outputs(ge.sgv(node), ge.sgv(const_lists)) + print("Trying to execute node {}".format(node.name)) + graph = node.graph + with graph.as_default(): + const_lists = [] + with tf.Session() as sess: + for out_t in node.outputs: + const_val = sess.run(out_t) + const_op = tf.constant(const_val).op + const_lists.append(const_op) + ge.swap_outputs(ge.sgv(node), ge.sgv(const_lists)) + def DFS(node, visited, const_map, deleted_nodes): - print("Visiting node {}".format(node.name)) - visited.add(node) - if node.type == "Const": - const_map[node.name] = True - return True - if len(node.inputs) == 0: + print("Visiting node {}".format(node.name)) + visited.add(node) + if node.type == "Const": + const_map[node.name] = True + return True + if len(node.inputs) == 0: + const_map[node.name] = False + return False + for inp_node in get_inputs(node): + if not inp_node in visited: + isConst = DFS(inp_node, visited, const_map, deleted_nodes) + const_map[inp_node.name] = isConst + all_inputs_const = True + for inp_node in get_inputs(node): + all_inputs_const = all_inputs_const and const_map[inp_node.name] + if all_inputs_const: + const_map[node.name] = True + replace_node_with_const(node) + deleted_nodes.add(node) + return True const_map[node.name] = False return False - for inp_node in get_inputs(node): - if not inp_node in visited: - isConst = DFS(inp_node, visited, const_map, deleted_nodes) - const_map[inp_node.name] = isConst - all_inputs_const = True - for inp_node in get_inputs(node): - all_inputs_const = all_inputs_const and const_map[inp_node.name] - if all_inputs_const: - const_map[node.name] = True - replace_node_with_const(node) - deleted_nodes.add(node) - return True - const_map[node.name] = False - return False + def get_dangling_consts_old(graph): - consts = [ i for i in graph.get_operations() if i.type == 'Const' ] - def has_users(op): - for i in op.outputs: - for j in i.consumers(): - if j.type != 'Const': - return True - return False - return [ i for i in consts if not has_users(i)] + consts = [i for i in graph.get_operations() if i.type == "Const"] + + def has_users(op): + for i in op.outputs: + for j in i.consumers(): + if j.type != "Const": + return True + return False + + return [i for i in consts if not has_users(i)] + def get_dangling_consts(graph, deleted_nodes): - consts = [ i for i in graph.get_operations() if i.type == 'Const' ] - def has_users(op): - for i in op.outputs: - for j in i.consumers(): - if j.type != 'Const' and (j not in deleted_nodes): - return True - return False - return [ i for i in consts if not has_users(i)] - + consts = [i for i in graph.get_operations() if i.type == "Const"] + + def has_users(op): + for i in op.outputs: + for j in i.consumers(): + if j.type != "Const" and (j not in deleted_nodes): + return True + return False + + return [i for i in consts if not has_users(i)] + + def fold_constants(graph): - visited = set({}) - const_map = {} - deleted_nodes = set({}) - with graph.as_default(): - for node in graph.get_operations(): - if not node in visited: - isConst = DFS(node, visited, const_map, deleted_nodes) - if isConst: - replace_node_with_const(node) - deleted_nodes.add(node) - useless_consts = get_dangling_consts(graph, deleted_nodes) - print("No. of consts to be removed = {}".format(len(useless_consts))) - deleted_nodes.update(useless_consts) - graph = delete_nodes(graph, deleted_nodes) - consts = [ i for i in graph.get_operations() if i.type == 'Const' ] - print("No. of total consts still remaining = {}".format(len(consts))) - dang_consts = get_dangling_consts_old(graph) - print("No. of dang consts still remaining = {}".format(len(dang_consts))) - return graph + visited = set({}) + const_map = {} + deleted_nodes = set({}) + with graph.as_default(): + for node in graph.get_operations(): + if not node in visited: + isConst = DFS(node, visited, const_map, deleted_nodes) + if isConst: + replace_node_with_const(node) + deleted_nodes.add(node) + useless_consts = get_dangling_consts(graph, deleted_nodes) + print("No. of consts to be removed = {}".format(len(useless_consts))) + deleted_nodes.update(useless_consts) + graph = delete_nodes(graph, deleted_nodes) + consts = [i for i in graph.get_operations() if i.type == "Const"] + print("No. of total consts still remaining = {}".format(len(consts))) + dang_consts = get_dangling_consts_old(graph) + print("No. of dang consts still remaining = {}".format(len(dang_consts))) + return graph + def replace_nodes_with_identity(graph, nop_splits): - with graph.as_default(): - for split in nop_splits: - inp_var = split.inputs[1] - identity = tf.identity(inp_var).op - ge.swap_outputs(ge.sgv(split), ge.sgv(identity)) - return graph - + with graph.as_default(): + for split in nop_splits: + inp_var = split.inputs[1] + identity = tf.identity(inp_var).op + ge.swap_outputs(ge.sgv(split), ge.sgv(identity)) + return graph + + def fold_splits(graph): - with graph.as_default(): - nop_splits = [] - for node in graph.get_operations(): - if node.type != "Split": - continue - if node.get_attr("num_split") == 1: - nop_splits.append(node) - replace_nodes_with_identity(graph, nop_splits) - graph = delete_nodes(graph, set(nop_splits)) - return graph + with graph.as_default(): + nop_splits = [] + for node in graph.get_operations(): + if node.type != "Split": + continue + if node.get_attr("num_split") == 1: + nop_splits.append(node) + replace_nodes_with_identity(graph, nop_splits) + graph = delete_nodes(graph, set(nop_splits)) + return graph diff --git a/Athos/HelperScripts/Confirm_preprocessing.py b/Athos/HelperScripts/Confirm_preprocessing.py index 7489be8d..06d99432 100644 --- a/Athos/HelperScripts/Confirm_preprocessing.py +++ b/Athos/HelperScripts/Confirm_preprocessing.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,31 +20,31 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" -import os,sys, functools +import os, sys, functools -if (len(sys.argv)!=4): - print("Incorrect args. Error.", file=sys.stderr) - exit(1) +if len(sys.argv) != 4: + print("Incorrect args. Error.", file=sys.stderr) + exit(1) preProcessedImagesDir = sys.argv[1] startImgNum = int(sys.argv[2]) endImgNum = int(sys.argv[3]) -numImg = endImgNum-startImgNum+1 -expectedShape = [224,224,3] -expectedNumElements = functools.reduce(lambda x,y : x*y, expectedShape) +numImg = endImgNum - startImgNum + 1 +expectedShape = [224, 224, 3] +expectedNumElements = functools.reduce(lambda x, y: x * y, expectedShape) print("ExpectedNumElements in all preprocessed images = ", expectedNumElements) badImages = [] -for i in range(startImgNum,endImgNum): - if ((i % (numImg/10))==0): - print("Reached i = {0}.".format(i)) - filename = os.path.join(preProcessedImagesDir, 'ImageNum_'+str(i)+'.inp') - with open(filename, 'r') as ff: - line = ff.readline() - val = list(map(lambda x : float(x) , line.split())) - if (len(val)!=expectedNumElements): - print("Expected num of elements not found in imagenum = {0}.".format(i)) - badImages.append(i) +for i in range(startImgNum, endImgNum): + if (i % (numImg / 10)) == 0: + print("Reached i = {0}.".format(i)) + filename = os.path.join(preProcessedImagesDir, "ImageNum_" + str(i) + ".inp") + with open(filename, "r") as ff: + line = ff.readline() + val = list(map(lambda x: float(x), line.split())) + if len(val) != expectedNumElements: + print("Expected num of elements not found in imagenum = {0}.".format(i)) + badImages.append(i) print("Found {0} bad images.".format(len(badImages))) diff --git a/Athos/HelperScripts/Convert_WnId_To_TrainId.py b/Athos/HelperScripts/Convert_WnId_To_TrainId.py index 2e2336c5..f62ff8d9 100644 --- a/Athos/HelperScripts/Convert_WnId_To_TrainId.py +++ b/Athos/HelperScripts/Convert_WnId_To_TrainId.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,23 +20,23 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" -with open('./ImageNet_ValData/imagenet_2012_validation_synset_labels.txt', 'r') as ff: - allLines = ff.readlines() +with open("./ImageNet_ValData/imagenet_2012_validation_synset_labels.txt", "r") as ff: + allLines = ff.readlines() -labels = list(map(lambda x : x.rstrip(), allLines)) +labels = list(map(lambda x: x.rstrip(), allLines)) sortedLabels = sorted(set(labels)) labelsMap = {} -for k,v in enumerate(sortedLabels): - labelsMap[v] = k+1 +for k, v in enumerate(sortedLabels): + labelsMap[v] = k + 1 -with open('./ImageNet_ValData/sortedWordnetIds.txt', 'w') as ff: - for x in sortedLabels: - ff.write(x + '\n') +with open("./ImageNet_ValData/sortedWordnetIds.txt", "w") as ff: + for x in sortedLabels: + ff.write(x + "\n") -with open('./ImageNet_ValData/imagenet12_val_nlabels.txt', 'w') as ff: - for curLabel in labels: - ff.write(str(labelsMap[curLabel])) - ff.write('\n') +with open("./ImageNet_ValData/imagenet12_val_nlabels.txt", "w") as ff: + for curLabel in labels: + ff.write(str(labelsMap[curLabel])) + ff.write("\n") diff --git a/Athos/HelperScripts/FindAccuracy.py b/Athos/HelperScripts/FindAccuracy.py index 4b98eaa3..cf5afc59 100644 --- a/Athos/HelperScripts/FindAccuracy.py +++ b/Athos/HelperScripts/FindAccuracy.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,25 +20,27 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # Use this script to find accuracy once the full exploration has been done on relevant scale factors. # NOTE: The ground truth labels are in [1, 1000]. -# Resnet outputs labels in [0,1000] -- class 0 is extraneous and is the other category. -# For DenseNet and SqueezeNet, the labels after argmax are in [0,999]: -# so either we have to do ground truth labels-1 or while outputing from the code for SqNet/DenseNet, -# add a +1. Choosing to go with the former. -# So, in summary, when running resnet, use the last parameter as 1, while for SqueezeNet/DenseNet use it as 0. +# Resnet outputs labels in [0,1000] -- class 0 is extraneous and is the other category. +# For DenseNet and SqueezeNet, the labels after argmax are in [0,999]: +# so either we have to do ground truth labels-1 or while outputing from the code for SqNet/DenseNet, +# add a +1. Choosing to go with the former. +# So, in summary, when running resnet, use the last parameter as 1, while for SqueezeNet/DenseNet use it as 0. import os, sys import numpy as np -if (len(sys.argv) < 4): - print("Usage : python3 FindAccuracy.py ") - exit(1) +if len(sys.argv) < 4: + print( + "Usage : python3 FindAccuracy.py " + ) + exit(1) # Change following parameters accordingly -ScalesToCheck = [10] #range(8,31) +ScalesToCheck = [10] # range(8,31) NumProcesses = 32 numImages = 50000 topK = 5 @@ -46,58 +48,74 @@ groundTruthLabelsFileName = sys.argv[1] inferenceOutputDirectory = sys.argv[2] lowerBoundOfOutputLabels = int(sys.argv[3]) -if (lowerBoundOfOutputLabels != 0 and lowerBoundOfOutputLabels != 1): - print("lowerBoundOfOutputLabels should be either 0 or 1. Exiting.", file=sys.stderr) - exit(1) +if lowerBoundOfOutputLabels != 0 and lowerBoundOfOutputLabels != 1: + print("lowerBoundOfOutputLabels should be either 0 or 1. Exiting.", file=sys.stderr) + exit(1) -with open(groundTruthLabelsFileName, 'r') as ff: - groundTruthLabels = ff.readlines() +with open(groundTruthLabelsFileName, "r") as ff: + groundTruthLabels = ff.readlines() + +groundTruthLabels = list( + map(lambda x: int(x.rstrip()), groundTruthLabels) +) # For imagenet, this is in [1,1000] +if lowerBoundOfOutputLabels == 0: + groundTruthLabels = list( + map(lambda x: x - 1, groundTruthLabels) + ) # If the labels in the output start from 0, + # subtract 1 from the ground truth labels. -groundTruthLabels = list(map(lambda x : int(x.rstrip()), groundTruthLabels)) #For imagenet, this is in [1,1000] -if (lowerBoundOfOutputLabels==0): - groundTruthLabels = list(map(lambda x : x-1, groundTruthLabels)) #If the labels in the output start from 0, - # subtract 1 from the ground truth labels. def parseInferenceOutputFile(predictions, outputFileName): - with open(outputFileName, 'r') as ff: - lines = ff.readlines() - lines = list(map(lambda x : x.rstrip(), lines)) - lines = list(filter(lambda x : x!='', lines)) - assert(len(lines)%2==0) - imgCounter = None - for line in lines: - if (line.startswith('Answer for')): - imgCounter = int(line.split('=')[1].split(':')[0]) #This is assumed to be 1-indexed - else: - assert(imgCounter is not None) - preds = line.split() - preds = np.array(list(map(lambda x : int(x), preds))) - topKPredsIdx = np.argpartition(preds, -1*topK)[-1*topK:] - topKPredsIdx = topKPredsIdx[np.argsort(preds[topKPredsIdx])] - for i,val in enumerate(topKPredsIdx): - predictions[imgCounter-1][i] = val + with open(outputFileName, "r") as ff: + lines = ff.readlines() + lines = list(map(lambda x: x.rstrip(), lines)) + lines = list(filter(lambda x: x != "", lines)) + assert len(lines) % 2 == 0 + imgCounter = None + for line in lines: + if line.startswith("Answer for"): + imgCounter = int( + line.split("=")[1].split(":")[0] + ) # This is assumed to be 1-indexed + else: + assert imgCounter is not None + preds = line.split() + preds = np.array(list(map(lambda x: int(x), preds))) + topKPredsIdx = np.argpartition(preds, -1 * topK)[-1 * topK :] + topKPredsIdx = topKPredsIdx[np.argsort(preds[topKPredsIdx])] + for i, val in enumerate(topKPredsIdx): + predictions[imgCounter - 1][i] = val + def calculateAccuracy(predictions): - global groundTruthLabels - top1CorrectPred = 0 - topKCorrectPred = 0 - for i in range(numImages): - if (groundTruthLabels[i] == predictions[i][-1]): - top1CorrectPred += 1 - if (groundTruthLabels[i] in predictions[i]): - topKCorrectPred += 1 - return (top1CorrectPred/(1.0*numImages), topKCorrectPred/(1.0*numImages)) + global groundTruthLabels + top1CorrectPred = 0 + topKCorrectPred = 0 + for i in range(numImages): + if groundTruthLabels[i] == predictions[i][-1]: + top1CorrectPred += 1 + if groundTruthLabels[i] in predictions[i]: + topKCorrectPred += 1 + return (top1CorrectPred / (1.0 * numImages), topKCorrectPred / (1.0 * numImages)) for curScale in ScalesToCheck: - predictions = [[None]*topK for _ in range(numImages)] - for curProcessNum in range(NumProcesses): - curFileName = os.path.join(inferenceOutputDirectory, 'output_' + str(curScale) + '_' + str(curProcessNum) + '.outp') - parseInferenceOutputFile(predictions, curFileName) - for i in range(numImages): - for j in range(topK): - assert(predictions[i][j] is not None) - (top1Acc, topKAcc) = calculateAccuracy(predictions) - print("curScale = " + str(curScale) + ", top1Acc = " + str(top1Acc) + ", topKAcc = " + str(topKAcc)) - - + predictions = [[None] * topK for _ in range(numImages)] + for curProcessNum in range(NumProcesses): + curFileName = os.path.join( + inferenceOutputDirectory, + "output_" + str(curScale) + "_" + str(curProcessNum) + ".outp", + ) + parseInferenceOutputFile(predictions, curFileName) + for i in range(numImages): + for j in range(topK): + assert predictions[i][j] is not None + (top1Acc, topKAcc) = calculateAccuracy(predictions) + print( + "curScale = " + + str(curScale) + + ", top1Acc = " + + str(top1Acc) + + ", topKAcc = " + + str(topKAcc) + ) diff --git a/Athos/HelperScripts/FindAccuracy_Porthos.py b/Athos/HelperScripts/FindAccuracy_Porthos.py index 794c64b9..0ab7917e 100644 --- a/Athos/HelperScripts/FindAccuracy_Porthos.py +++ b/Athos/HelperScripts/FindAccuracy_Porthos.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,25 +20,50 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # Use this script to find accuracy once the full exploration has been done on relevant scale factors. # NOTE: The ground truth labels are in [1, 1000]. -# Resnet outputs labels in [0,1000] -- class 0 is extraneous and is the other category. -# For DenseNet and SqueezeNet, the labels after argmax are in [0,999]: -# so either we have to do ground truth labels-1 or while outputing from the code for SqNet/DenseNet, -# add a +1. Choosing to go with the former. -# So, in summary, when running resnet, use the last parameter as 1, while for SqueezeNet/DenseNet use it as 0. +# Resnet outputs labels in [0,1000] -- class 0 is extraneous and is the other category. +# For DenseNet and SqueezeNet, the labels after argmax are in [0,999]: +# so either we have to do ground truth labels-1 or while outputing from the code for SqNet/DenseNet, +# add a +1. Choosing to go with the former. +# So, in summary, when running resnet, use the last parameter as 1, while for SqueezeNet/DenseNet use it as 0. import os, sys import numpy as np -if (len(sys.argv) < 4): - print("Usage : python3 FindAccuracy_Porthos.py ") - exit(1) +if len(sys.argv) < 4: + print( + "Usage : python3 FindAccuracy_Porthos.py " + ) + exit(1) # Change following parameters accordingly -ScalesToCheck = [9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +ScalesToCheck = [ + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, +] # ScalesToCheck = [11] numImages = 96 topK = 5 @@ -46,63 +71,79 @@ groundTruthLabelsFileName = sys.argv[1] inferenceOutputDirectory = sys.argv[2] lowerBoundOfOutputLabels = int(sys.argv[3]) -if (lowerBoundOfOutputLabels != 0 and lowerBoundOfOutputLabels != 1): - print("lowerBoundOfOutputLabels should be either 0 or 1. Exiting.", file=sys.stderr) - exit(1) +if lowerBoundOfOutputLabels != 0 and lowerBoundOfOutputLabels != 1: + print("lowerBoundOfOutputLabels should be either 0 or 1. Exiting.", file=sys.stderr) + exit(1) -with open(groundTruthLabelsFileName, 'r') as ff: - groundTruthLabels = ff.readlines() +with open(groundTruthLabelsFileName, "r") as ff: + groundTruthLabels = ff.readlines() + +groundTruthLabels = list( + map(lambda x: int(x.rstrip()), groundTruthLabels) +) # For imagenet, this is in [1,1000] +if lowerBoundOfOutputLabels == 0: + groundTruthLabels = list( + map(lambda x: x - 1, groundTruthLabels) + ) # If the labels in the output start from 0, + # subtract 1 from the ground truth labels. -groundTruthLabels = list(map(lambda x : int(x.rstrip()), groundTruthLabels)) #For imagenet, this is in [1,1000] -if (lowerBoundOfOutputLabels==0): - groundTruthLabels = list(map(lambda x : x-1, groundTruthLabels)) #If the labels in the output start from 0, - # subtract 1 from the ground truth labels. def parseInferenceOutputFile(predictions, i, outputFileName): - with open(outputFileName, 'r') as ff: - lines = ff.readlines() - lines = list(map(lambda x : x.rstrip(), lines)) - lines = list(filter(lambda x : x!='', lines)) - if(len(lines)!=2): - print("Error in parsing : "+outputFileName) - assert(False) - imgCounter = None - for line in lines: - if (line.startswith('Answer for')): - imgCounter = int(line.split('=')[1].split(':')[0]) #This is assumed to be 0-indexed - else: - assert(imgCounter is not None) - preds = line.split() - # print(imgCounter,preds) - preds = np.array(list(map(lambda x : int(x), preds))) - topKPredsIdx = np.argpartition(preds, -1*topK)[-1*topK:] - topKPredsIdx = topKPredsIdx[np.argsort(preds[topKPredsIdx])] - for i, val in enumerate(topKPredsIdx): - predictions[imgCounter][i] = val + with open(outputFileName, "r") as ff: + lines = ff.readlines() + lines = list(map(lambda x: x.rstrip(), lines)) + lines = list(filter(lambda x: x != "", lines)) + if len(lines) != 2: + print("Error in parsing : " + outputFileName) + assert False + imgCounter = None + for line in lines: + if line.startswith("Answer for"): + imgCounter = int( + line.split("=")[1].split(":")[0] + ) # This is assumed to be 0-indexed + else: + assert imgCounter is not None + preds = line.split() + # print(imgCounter,preds) + preds = np.array(list(map(lambda x: int(x), preds))) + topKPredsIdx = np.argpartition(preds, -1 * topK)[-1 * topK :] + topKPredsIdx = topKPredsIdx[np.argsort(preds[topKPredsIdx])] + for i, val in enumerate(topKPredsIdx): + predictions[imgCounter][i] = val + def calculateAccuracy(predictions): - global groundTruthLabels - top1CorrectPred = 0 - topKCorrectPred = 0 - for i in range(numImages): - if not(predictions[i][0]): - continue - if (groundTruthLabels[i] == predictions[i][-1]): - top1CorrectPred += 1 - if (groundTruthLabels[i] in predictions[i]): - topKCorrectPred += 1 - return (top1CorrectPred/(1.0*numImages), topKCorrectPred/(1.0*numImages)) + global groundTruthLabels + top1CorrectPred = 0 + topKCorrectPred = 0 + for i in range(numImages): + if not (predictions[i][0]): + continue + if groundTruthLabels[i] == predictions[i][-1]: + top1CorrectPred += 1 + if groundTruthLabels[i] in predictions[i]: + topKCorrectPred += 1 + return (top1CorrectPred / (1.0 * numImages), topKCorrectPred / (1.0 * numImages)) for curScale in ScalesToCheck: - predictions = [[None]*topK for _ in range(numImages)] - for i in range(numImages): - curFileName = os.path.join(inferenceOutputDirectory, 'stderr_' + str(curScale) + '_' + str(i) +'_proc_1.outp') - parseInferenceOutputFile(predictions, i, curFileName) - for i in range(numImages): - for j in range(topK): - assert(predictions[i][j] is not None) - (top1Acc, topKAcc) = calculateAccuracy(predictions) - print("curScale = " + str(curScale) + ", top1Acc = " + str(top1Acc) + ", topKAcc = " + str(topKAcc)) - - + predictions = [[None] * topK for _ in range(numImages)] + for i in range(numImages): + curFileName = os.path.join( + inferenceOutputDirectory, + "stderr_" + str(curScale) + "_" + str(i) + "_proc_1.outp", + ) + parseInferenceOutputFile(predictions, i, curFileName) + for i in range(numImages): + for j in range(topK): + assert predictions[i][j] is not None + (top1Acc, topKAcc) = calculateAccuracy(predictions) + print( + "curScale = " + + str(curScale) + + ", top1Acc = " + + str(top1Acc) + + ", topKAcc = " + + str(topKAcc) + ) diff --git a/Athos/HelperScripts/FindAccuracy_TF.py b/Athos/HelperScripts/FindAccuracy_TF.py index d0266d78..1212dca5 100644 --- a/Athos/HelperScripts/FindAccuracy_TF.py +++ b/Athos/HelperScripts/FindAccuracy_TF.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,22 +20,24 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # Use this script to find accuracy once the full exploration has been done on relevant scale factors. # NOTE: The ground truth labels are in [1, 1000]. -# Resnet outputs labels in [0,1000] -- class 0 is extraneous and is the other category. -# For DenseNet and SqueezeNet, the labels after argmax are in [0,999]: -# so either we have to do ground truth labels-1 or while outputing from the code for SqNet/DenseNet, -# add a +1. Choosing to go with the former. -# So, in summary, when running resnet, use the last parameter as 1, while for SqueezeNet/DenseNet use it as 0. +# Resnet outputs labels in [0,1000] -- class 0 is extraneous and is the other category. +# For DenseNet and SqueezeNet, the labels after argmax are in [0,999]: +# so either we have to do ground truth labels-1 or while outputing from the code for SqNet/DenseNet, +# add a +1. Choosing to go with the former. +# So, in summary, when running resnet, use the last parameter as 1, while for SqueezeNet/DenseNet use it as 0. import os, sys import numpy as np -if (len(sys.argv) < 5): - print("Usage : python3 FindAccuracy.py ") - exit(1) +if len(sys.argv) < 5: + print( + "Usage : python3 FindAccuracy.py " + ) + exit(1) numImages = 50000 topK = 5 @@ -44,57 +46,66 @@ tfInferenceArgmaxFileName = sys.argv[2] tfInferenceAllOutputFileName = sys.argv[3] lowerBoundOfOutputLabels = int(sys.argv[4]) -if (lowerBoundOfOutputLabels != 0 and lowerBoundOfOutputLabels != 1): - print("lowerBoundOfOutputLabels should be either 0 or 1. Exiting.", file=sys.stderr) - exit(1) +if lowerBoundOfOutputLabels != 0 and lowerBoundOfOutputLabels != 1: + print("lowerBoundOfOutputLabels should be either 0 or 1. Exiting.", file=sys.stderr) + exit(1) + +with open(groundTruthLabelsFileName, "r") as ff: + groundTruthLabels = ff.readlines() + +groundTruthLabels = list( + map(lambda x: int(x.rstrip()), groundTruthLabels) +) # For imagenet, this is in [1,1000] +if lowerBoundOfOutputLabels == 0: + groundTruthLabels = list( + map(lambda x: x - 1, groundTruthLabels) + ) # If the labels in the output start from 0, + # subtract 1 from the ground truth labels. +with open(tfInferenceArgmaxFileName, "r") as ff: + tfInferenceArgmaxOutputs = ff.readlines() +tfInferenceArgmaxOutputs = list( + map(lambda x: int(x.rstrip()), tfInferenceArgmaxOutputs) +) -with open(groundTruthLabelsFileName, 'r') as ff: - groundTruthLabels = ff.readlines() - -groundTruthLabels = list(map(lambda x : int(x.rstrip()), groundTruthLabels)) #For imagenet, this is in [1,1000] -if (lowerBoundOfOutputLabels==0): - groundTruthLabels = list(map(lambda x : x-1, groundTruthLabels)) #If the labels in the output start from 0, - # subtract 1 from the ground truth labels. -with open(tfInferenceArgmaxFileName, 'r') as ff: - tfInferenceArgmaxOutputs = ff.readlines() -tfInferenceArgmaxOutputs = list(map(lambda x : int(x.rstrip()), tfInferenceArgmaxOutputs)) def parseInferenceOutputFile(predictions, outputFileName): - with open(outputFileName, 'r') as ff: - lines = ff.readlines() - lines = list(map(lambda x : x.rstrip(), lines)) - lines = list(filter(lambda x : x!='', lines)) - assert(len(lines)%2==0) - imgCounter = None - for line in lines: - if (line.startswith('Answer for')): - imgCounter = int(line.split('=')[1]) #This is assumed to be 1-indexed - else: - assert(imgCounter is not None) - preds = line.split() - preds = np.array(list(map(lambda x : float(x), preds))) - topKPredsIdx = np.argpartition(preds, -1*topK)[-1*topK:] - topKPredsIdx = topKPredsIdx[np.argsort(preds[topKPredsIdx])] - for i,val in enumerate(topKPredsIdx): - predictions[imgCounter-1][i] = val + with open(outputFileName, "r") as ff: + lines = ff.readlines() + lines = list(map(lambda x: x.rstrip(), lines)) + lines = list(filter(lambda x: x != "", lines)) + assert len(lines) % 2 == 0 + imgCounter = None + for line in lines: + if line.startswith("Answer for"): + imgCounter = int(line.split("=")[1]) # This is assumed to be 1-indexed + else: + assert imgCounter is not None + preds = line.split() + preds = np.array(list(map(lambda x: float(x), preds))) + topKPredsIdx = np.argpartition(preds, -1 * topK)[-1 * topK :] + topKPredsIdx = topKPredsIdx[np.argsort(preds[topKPredsIdx])] + for i, val in enumerate(topKPredsIdx): + predictions[imgCounter - 1][i] = val + def calculateAccuracy(predictions): - global groundTruthLabels - top1CorrectPred = 0 - topKCorrectPred = 0 - for i in range(numImages): - if (groundTruthLabels[i] == predictions[i][-1]): - top1CorrectPred += 1 - if (groundTruthLabels[i] in predictions[i]): - topKCorrectPred += 1 - return (top1CorrectPred/(1.0*numImages), topKCorrectPred/(1.0*numImages)) - -predictions = [[None]*topK for _ in range(numImages)] + global groundTruthLabels + top1CorrectPred = 0 + topKCorrectPred = 0 + for i in range(numImages): + if groundTruthLabels[i] == predictions[i][-1]: + top1CorrectPred += 1 + if groundTruthLabels[i] in predictions[i]: + topKCorrectPred += 1 + return (top1CorrectPred / (1.0 * numImages), topKCorrectPred / (1.0 * numImages)) + + +predictions = [[None] * topK for _ in range(numImages)] parseInferenceOutputFile(predictions, tfInferenceAllOutputFileName) for i in range(numImages): - assert(predictions[i][-1]==tfInferenceArgmaxOutputs[i]) - for j in range(topK): - assert(predictions[i][j] is not None) + assert predictions[i][-1] == tfInferenceArgmaxOutputs[i] + for j in range(topK): + assert predictions[i][j] is not None (top1Acc, topKAcc) = calculateAccuracy(predictions) print(top1Acc, topKAcc) @@ -117,7 +128,7 @@ def calculateAccuracy(predictions): # if (groundTruthLabels[idx] == predictions[idx][-1]): # i25kNotSelectedCorrectTop1 += 1 # if (groundTruthLabels[idx] in predictions[idx]): -# i25kNotSelectedCorrectTop5 += 1 +# i25kNotSelectedCorrectTop5 += 1 # assert(i25kSelectedCorrectTop1 + i25kNotSelectedCorrectTop1 == (top1Acc*numImages)) # assert(i25kSelectedCorrectTop5 + i25kNotSelectedCorrectTop5 == (topKAcc*numImages)) @@ -125,4 +136,3 @@ def calculateAccuracy(predictions): # print(i25kNotSelectedCorrectTop1, i25kNotSelectedCorrectTop5) # print(i25kSelectedCorrectTop1/(1.0*25000), i25kSelectedCorrectTop5/(1.0*25000)) # print(i25kNotSelectedCorrectTop1/(1.0*25000), i25kNotSelectedCorrectTop5/(1.0*25000)) - diff --git a/Athos/HelperScripts/Random_Image_Selection.py b/Athos/HelperScripts/Random_Image_Selection.py index b5571e6c..d70ebadf 100644 --- a/Athos/HelperScripts/Random_Image_Selection.py +++ b/Athos/HelperScripts/Random_Image_Selection.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,23 +20,23 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # Choose a random subset of SELECT_IMAGES from TOTAL_IMAGES import random -TOTAL_IMAGES = 50000 #ImageNet validation dataset size is 50k images -SELECT_IMAGES = 1000 #Choose a random subset of 1k images -with open('./ImageNet_ValData/imagenet12_val_nlabels.txt', 'r') as ff: - labels = ff.readlines() +TOTAL_IMAGES = 50000 # ImageNet validation dataset size is 50k images +SELECT_IMAGES = 1000 # Choose a random subset of 1k images +with open("./ImageNet_ValData/imagenet12_val_nlabels.txt", "r") as ff: + labels = ff.readlines() idxSelected = random.sample(range(TOTAL_IMAGES), SELECT_IMAGES) labelSelected = [labels[i] for i in idxSelected] -with open("./ImageNet_ValData/Random_Img_Idx.txt", 'w') as ff: - for x in idxSelected: - ff.write(str(x+1)+"\n") #make it 1-indexed as image numbers are 1 indexed +with open("./ImageNet_ValData/Random_Img_Idx.txt", "w") as ff: + for x in idxSelected: + ff.write(str(x + 1) + "\n") # make it 1-indexed as image numbers are 1 indexed with open("./ImageNet_ValData/Random_Img_Labels.txt", "w") as ff: - for x in labelSelected: - ff.write(str(x)) + for x in labelSelected: + ff.write(str(x)) diff --git a/Athos/HelperScripts/Scale_img_and_model.py b/Athos/HelperScripts/Scale_img_and_model.py index b880e210..ac9d00f7 100644 --- a/Athos/HelperScripts/Scale_img_and_model.py +++ b/Athos/HelperScripts/Scale_img_and_model.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,67 +20,82 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os, sys numImages = 10 -newScaledFileNameSuffix = '_scaled_' +newScaledFileNameSuffix = "_scaled_" scalingFactors = [12] -if (len(sys.argv)!=3): - print("Incorrect args: Run as python3 Scale_img_and_model.py ", file=sys.stderr) - exit(1) +if len(sys.argv) != 3: + print( + "Incorrect args: Run as python3 Scale_img_and_model.py ", + file=sys.stderr, + ) + exit(1) modelFileName = sys.argv[1] floatImgDir = sys.argv[2] randomImgIdxFileName = "./ImageNet_ValData/Random_Img_Idx.txt" + def checkIfFileExists(filename): - if not(os.path.exists(filename)): - print("Expected file doesn't exist. Error. FileName = {0}.".format(filename), file=sys.stderr) - exit(1) + if not (os.path.exists(filename)): + print( + "Expected file doesn't exist. Error. FileName = {0}.".format(filename), + file=sys.stderr, + ) + exit(1) + checkIfFileExists(randomImgIdxFileName) -with open(randomImgIdxFileName, 'r') as ff: - imgIdx = ff.readlines() -imgIdx = list(map(lambda x: int(x.rstrip()) , imgIdx)) +with open(randomImgIdxFileName, "r") as ff: + imgIdx = ff.readlines() +imgIdx = list(map(lambda x: int(x.rstrip()), imgIdx)) + def scaleImg(filename, scalingFac): - checkIfFileExists(filename) - with open(filename, 'r') as ff: - line = ff.readline() - val = line.split() - val = list(map(lambda x : int(float(x)*(1< maxibits): - maxibits = curbits - line = ff.readline() +with open(filename, "r") as ff: + line = ff.readline() + while line: + if not (line.startswith("Matmul")): + val = line.split() + val = list(map(lambda x: int(x), val)) + for elem in val: + curbits = len(bin(elem)) - 2 + if curbits in bitsdict: + bitsdict[curbits] += 1 + else: + bitsdict[curbits] = 0 + if curbits > maxibits: + maxibits = curbits + line = ff.readline() print(maxibits) summ = 0 for k in sorted(bitsdict.keys()): - print(k,bitsdict[k]) - summ+=bitsdict[k] + print(k, bitsdict[k]) + summ += bitsdict[k] print("summ = ", summ) prob = 0 for k in sorted(bitsdict.keys()): - curprob = bitsdict[k]/(1<<(64-k-1)) - print("curprob ",k,curprob) - prob += curprob + curprob = bitsdict[k] / (1 << (64 - k - 1)) + print("curprob ", k, curprob) + prob += curprob print(prob) diff --git a/Athos/Networks/ChestXRay/ChestXRay_tf_main.py b/Athos/Networks/ChestXRay/ChestXRay_tf_main.py index 312ff3ae..ee7fcb30 100644 --- a/Athos/Networks/ChestXRay/ChestXRay_tf_main.py +++ b/Athos/Networks/ChestXRay/ChestXRay_tf_main.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import cv2, numpy, sys, os, argparse, time import tensorflow as tf @@ -28,6 +28,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from tensorflow.python.util import deprecation + deprecation._PRINT_DEPRECATION_WARNINGS = False try: from tensorflow.python.util import module_wrapper as deprecation @@ -36,83 +37,131 @@ deprecation._PER_MODULE_WARNING_LIMIT = 0 -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData + def get_preprocessed_image(filename): - resized_height = 224 - resized_width = 224 - test_image_path = filename - cv2_image = cv2.resize(cv2.imread(test_image_path), (resized_height, resized_width))#.astype(numpy.float32) - cv2_image = cv2_image - numpy.min(cv2_image) - cv2_image = cv2_image/numpy.ptp(cv2_image) - cv2_image = 255*cv2_image - cv2_image = cv2_image.astype('uint8') - return cv2_image + resized_height = 224 + resized_width = 224 + test_image_path = filename + cv2_image = cv2.resize( + cv2.imread(test_image_path), (resized_height, resized_width) + ) # .astype(numpy.float32) + cv2_image = cv2_image - numpy.min(cv2_image) + cv2_image = cv2_image / numpy.ptp(cv2_image) + cv2_image = 255 * cv2_image + cv2_image = cv2_image.astype("uint8") + return cv2_image -def parseArgs(): - parser = argparse.ArgumentParser() - parser.add_argument("--savePreTrainedWeightsInt", type=bool, default=False, help="savePreTrainedWeightsInt") - parser.add_argument("--savePreTrainedWeightsFloat", type=bool, default=False, help="savePreTrainedWeightsFloat") - parser.add_argument("--scalingFac", type=int, default=15, help="scalingFac") - parser.add_argument("--runPrediction", type=bool, default=False, help="runPrediction") - parser.add_argument("--saveImgAndWtData", type=bool, default=False, help="saveImgAndWtData") +def parseArgs(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--savePreTrainedWeightsInt", + type=bool, + default=False, + help="savePreTrainedWeightsInt", + ) + parser.add_argument( + "--savePreTrainedWeightsFloat", + type=bool, + default=False, + help="savePreTrainedWeightsFloat", + ) + parser.add_argument("--scalingFac", type=int, default=15, help="scalingFac") + parser.add_argument( + "--runPrediction", type=bool, default=False, help="runPrediction" + ) + parser.add_argument( + "--saveImgAndWtData", type=bool, default=False, help="saveImgAndWtData" + ) + + args = parser.parse_args() + return args - args = parser.parse_args() - return args args = parseArgs() -imagesTemp = get_preprocessed_image('./SampleImages/00014251_029.png') -images = numpy.zeros(shape=(1,224,224,3)) +imagesTemp = get_preprocessed_image("./SampleImages/00014251_029.png") +images = numpy.zeros(shape=(1, 224, 224, 3)) images[0] = imagesTemp -feed_dict={'input_1:0' : images} +feed_dict = {"input_1:0": images} with tf.Session() as sess: - saver = tf.train.import_meta_graph("./PreTrainedModel/TFModel/model.meta") - sess.run(tf.global_variables_initializer()) - - # Find output tensor - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'dense_1/Sigmoid': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - - assert(output_tensor is not None) - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - if args.savePreTrainedWeightsInt or args.savePreTrainedWeightsFloat or args.runPrediction or args.saveImgAndWtData: - saver.restore(sess, "./PreTrainedModel/TFModel/model") - if args.savePreTrainedWeightsInt or args.savePreTrainedWeightsFloat or args.saveImgAndWtData: - DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess, feed_dict) - - predictions = None - if args.runPrediction: - print("*************** Starting Prediction****************") - start_time = time.time() - predictions = sess.run(output_tensor, feed_dict=feed_dict) - end_time = time.time() - print("*************** Done Prediction****************") - duration = end_time - start_time - print("Time taken in inference : ", duration) - print(predictions) - with open('tf_pred.float','w+') as f: - f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) - with open('tf_pred.time','w') as f: - f.write(str(round(duration, 2))) - - trainVarsName = [] - for node in optimized_graph_def.node: - if node.op=="VariableV2": - trainVarsName.append(node.name) - trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) - if args.savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeights(sess, trainVars, "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac, 'w') - if args.savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'model_weights_float.inp', 'w') - if args.saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, "model_input_scale_{}.inp".format(args.scalingFac), - "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac) - + saver = tf.train.import_meta_graph("./PreTrainedModel/TFModel/model.meta") + sess.run(tf.global_variables_initializer()) + + # Find output tensor + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "dense_1/Sigmoid": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + + assert output_tensor is not None + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) + + if ( + args.savePreTrainedWeightsInt + or args.savePreTrainedWeightsFloat + or args.runPrediction + or args.saveImgAndWtData + ): + saver.restore(sess, "./PreTrainedModel/TFModel/model") + if ( + args.savePreTrainedWeightsInt + or args.savePreTrainedWeightsFloat + or args.saveImgAndWtData + ): + DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess, feed_dict) + + predictions = None + if args.runPrediction: + print("*************** Starting Prediction****************") + start_time = time.time() + predictions = sess.run(output_tensor, feed_dict=feed_dict) + end_time = time.time() + print("*************** Done Prediction****************") + duration = end_time - start_time + print("Time taken in inference : ", duration) + print(predictions) + with open("tf_pred.float", "w+") as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) + with open("tf_pred.time", "w") as f: + f.write(str(round(duration, 2))) + + trainVarsName = [] + for node in optimized_graph_def.node: + if node.op == "VariableV2": + trainVarsName.append(node.name) + trainVars = list( + map( + lambda x: tf.get_default_graph().get_operation_by_name(x).outputs[0], + trainVarsName, + ) + ) + if args.savePreTrainedWeightsInt: + DumpTFMtData.dumpTrainedWeights( + sess, + trainVars, + "model_weights_scale_{}.inp".format(args.scalingFac), + args.scalingFac, + "w", + ) + if args.savePreTrainedWeightsFloat: + DumpTFMtData.dumpTrainedWeightsFloat( + sess, trainVars, "model_weights_float.inp", "w" + ) + if args.saveImgAndWtData: + DumpTFMtData.dumpImgAndWeightsDataSeparate( + sess, + images[0], + trainVars, + "model_input_scale_{}.inp".format(args.scalingFac), + "model_weights_scale_{}.inp".format(args.scalingFac), + args.scalingFac, + ) diff --git a/Athos/Networks/DenseNet/AccuracyAnalysisHelper/DenseNet_main_float_acc.py b/Athos/Networks/DenseNet/AccuracyAnalysisHelper/DenseNet_main_float_acc.py index bcdf6a35..1e3cdcfb 100644 --- a/Athos/Networks/DenseNet/AccuracyAnalysisHelper/DenseNet_main_float_acc.py +++ b/Athos/Networks/DenseNet/AccuracyAnalysisHelper/DenseNet_main_float_acc.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy import argparse @@ -28,55 +28,63 @@ import tensorflow as tf import _pickle as pickle -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import nets_factory batchsize = 1000 N = 50000 -model_name = 'densenet121' +model_name = "densenet121" num_classes = 1000 network_fn = nets_factory.get_network_fn( - model_name, - num_classes=num_classes, - is_training=False) + model_name, num_classes=num_classes, is_training=False +) -finalActivationsFileName = 'floating_point_acc.outp' -argmaxOutputFileName = 'floating_point_argmax.outp' +finalActivationsFileName = "floating_point_acc.outp" +argmaxOutputFileName = "floating_point_argmax.outp" -imagesPlaceHolder = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_x') +imagesPlaceHolder = tf.placeholder( + tf.float32, shape=(None, 224, 224, 3), name="input_x" +) logits, end_points = network_fn(imagesPlaceHolder) with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) + sess.run(tf.global_variables_initializer()) - modelPath = '../PreTrainedModel/tf-densenet121.ckpt' - saver = tf.train.Saver() - saver.restore(sess, modelPath) - - with open(finalActivationsFileName,'w') as ff: - pass - with open(argmaxOutputFileName,'w') as ff: - pass - numbatches = N//batchsize - for batchNum in range(numbatches): - startImgNum = (batchNum*batchsize) + 1 - endImgNum = N if (batchNum == numbatches-1) else (((batchNum+1)*batchsize)) - print("Processing images from start,end = {0}, {1}".format(startImgNum, endImgNum)) - images = numpy.zeros(shape=(endImgNum-startImgNum+1,224,224,3)) - for curImgNum in range(startImgNum, endImgNum+1): - with open('./PreProcessedImages/ImageNum_'+str(curImgNum)+'.inp', 'r') as ff: - line = ff.readline() - images[curImgNum-startImgNum] = numpy.reshape(list(map(lambda x : float(x), line.split())), (224,224,3)) - feed_dict = {imagesPlaceHolder: images} - predictions = sess.run(logits, feed_dict=feed_dict) - with open(finalActivationsFileName, 'a') as ff: - with open(argmaxOutputFileName, 'a') as gg: - for i in range(endImgNum-startImgNum+1): - ff.write('Answer for imgCounter = ' + str(startImgNum+i) + '\n') - for elem in numpy.nditer(predictions[i],order='C'): - ff.write(str(elem)+' ') - ff.write('\n\n') - gg.write('Answer for imgCounter = '+str(startImgNum+i)+' is ') - gg.write(str(numpy.argmax(predictions[i], 2))+'\n') + modelPath = "../PreTrainedModel/tf-densenet121.ckpt" + saver = tf.train.Saver() + saver.restore(sess, modelPath) + with open(finalActivationsFileName, "w") as ff: + pass + with open(argmaxOutputFileName, "w") as ff: + pass + numbatches = N // batchsize + for batchNum in range(numbatches): + startImgNum = (batchNum * batchsize) + 1 + endImgNum = ( + N if (batchNum == numbatches - 1) else (((batchNum + 1) * batchsize)) + ) + print( + "Processing images from start,end = {0}, {1}".format(startImgNum, endImgNum) + ) + images = numpy.zeros(shape=(endImgNum - startImgNum + 1, 224, 224, 3)) + for curImgNum in range(startImgNum, endImgNum + 1): + with open( + "./PreProcessedImages/ImageNum_" + str(curImgNum) + ".inp", "r" + ) as ff: + line = ff.readline() + images[curImgNum - startImgNum] = numpy.reshape( + list(map(lambda x: float(x), line.split())), (224, 224, 3) + ) + feed_dict = {imagesPlaceHolder: images} + predictions = sess.run(logits, feed_dict=feed_dict) + with open(finalActivationsFileName, "a") as ff: + with open(argmaxOutputFileName, "a") as gg: + for i in range(endImgNum - startImgNum + 1): + ff.write("Answer for imgCounter = " + str(startImgNum + i) + "\n") + for elem in numpy.nditer(predictions[i], order="C"): + ff.write(str(elem) + " ") + ff.write("\n\n") + gg.write("Answer for imgCounter = " + str(startImgNum + i) + " is ") + gg.write(str(numpy.argmax(predictions[i], 2)) + "\n") diff --git a/Athos/Networks/DenseNet/DenseNet_main.py b/Athos/Networks/DenseNet/DenseNet_main.py index 4531aa5e..147085db 100644 --- a/Athos/Networks/DenseNet/DenseNet_main.py +++ b/Athos/Networks/DenseNet/DenseNet_main.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy import argparse @@ -31,6 +31,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from tensorflow.python.util import deprecation + deprecation._PRINT_DEPRECATION_WARNINGS = False try: from tensorflow.python.util import module_wrapper as deprecation @@ -39,17 +40,19 @@ deprecation._PER_MODULE_WARNING_LIMIT = 0 import nets_factory -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData -model_name = 'densenet121' +model_name = "densenet121" num_classes = 1000 network_fn = nets_factory.get_network_fn( - model_name, - num_classes=num_classes, - is_training=False) + model_name, num_classes=num_classes, is_training=False +) -imagesPlaceHolder = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_x') +imagesPlaceHolder = tf.placeholder( + tf.float32, shape=(None, 224, 224, 3), name="input_x" +) logits, end_points = network_fn(imagesPlaceHolder) pred = tf.argmax(logits, 3) @@ -60,73 +63,118 @@ # imagesTemp = list(map(lambda x : float(x), line)) # imagesTemp = numpy.reshape(imagesTemp, (224,224,3)) -sampleImageFilePath = './SampleImages/n02109961_36_denseNet_preprocessed.pkl' -with open(sampleImageFilePath, 'rb') as ff: - imagesTemp = pickle.load(ff) -images = numpy.zeros(shape=(1,224,224,3)) +sampleImageFilePath = "./SampleImages/n02109961_36_denseNet_preprocessed.pkl" +with open(sampleImageFilePath, "rb") as ff: + imagesTemp = pickle.load(ff) +images = numpy.zeros(shape=(1, 224, 224, 3)) images[0] = imagesTemp -feed_dict = {imagesPlaceHolder : images} +feed_dict = {imagesPlaceHolder: images} -def parseArgs(): - parser = argparse.ArgumentParser() - parser.add_argument("--savePreTrainedWeightsInt", type=bool, default=False, help="savePreTrainedWeightsInt") - parser.add_argument("--savePreTrainedWeightsFloat", type=bool, default=False, help="savePreTrainedWeightsFloat") - parser.add_argument("--scalingFac", type=int, default=15, help="scalingFac") - parser.add_argument("--runPrediction", type=bool, default=False, help="runPrediction") - parser.add_argument("--saveImgAndWtData", type=bool, default=False, help="saveImgAndWtData") +def parseArgs(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--savePreTrainedWeightsInt", + type=bool, + default=False, + help="savePreTrainedWeightsInt", + ) + parser.add_argument( + "--savePreTrainedWeightsFloat", + type=bool, + default=False, + help="savePreTrainedWeightsFloat", + ) + parser.add_argument("--scalingFac", type=int, default=15, help="scalingFac") + parser.add_argument( + "--runPrediction", type=bool, default=False, help="runPrediction" + ) + parser.add_argument( + "--saveImgAndWtData", type=bool, default=False, help="saveImgAndWtData" + ) + + args = parser.parse_args() + return args - args = parser.parse_args() - return args args = parseArgs() with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - # if node.name == 'densenet121/logits/BiasAdd': - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - - assert(output_tensor is not None) - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - if args.savePreTrainedWeightsInt or args.savePreTrainedWeightsFloat or args.runPrediction or args.saveImgAndWtData: - modelPath = './PreTrainedModel/tf-densenet121.ckpt' - saver = tf.train.Saver() - saver.restore(sess, modelPath) - if args.savePreTrainedWeightsInt or args.savePreTrainedWeightsFloat or args.saveImgAndWtData: - DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess, feed_dict) - - predictions = None - if args.runPrediction: - print("*************** Starting Prediction****************") - start_time = time.time() - predictions = sess.run(output_tensor, feed_dict=feed_dict) - end_time = time.time() - print("*************** Done Prediction****************") - duration = end_time - start_time - print("Time taken in inference : ", duration) - with open('tf_pred.float','w+') as f: - f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) - with open('tf_pred.time','w') as f: - f.write(str(round(duration, 2))) - - print("Prediction = ", predictions) - - trainVarsName = [] - for node in optimized_graph_def.node: - if node.op=="VariableV2": - trainVarsName.append(node.name) - trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) - if args.savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeightsInt(sess, trainVars, "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac, 'w') - if args.savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'model_weights_float.inp', 'w') - if args.saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, "model_input_scale_{}.inp".format(args.scalingFac), - "model_weights_scale_{}.inp".format(args.scalingFac), args.scalingFac) - + sess.run(tf.global_variables_initializer()) + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + # if node.name == 'densenet121/logits/BiasAdd': + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + + assert output_tensor is not None + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) + + if ( + args.savePreTrainedWeightsInt + or args.savePreTrainedWeightsFloat + or args.runPrediction + or args.saveImgAndWtData + ): + modelPath = "./PreTrainedModel/tf-densenet121.ckpt" + saver = tf.train.Saver() + saver.restore(sess, modelPath) + if ( + args.savePreTrainedWeightsInt + or args.savePreTrainedWeightsFloat + or args.saveImgAndWtData + ): + DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess, feed_dict) + + predictions = None + if args.runPrediction: + print("*************** Starting Prediction****************") + start_time = time.time() + predictions = sess.run(output_tensor, feed_dict=feed_dict) + end_time = time.time() + print("*************** Done Prediction****************") + duration = end_time - start_time + print("Time taken in inference : ", duration) + with open("tf_pred.float", "w+") as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) + with open("tf_pred.time", "w") as f: + f.write(str(round(duration, 2))) + + print("Prediction = ", predictions) + + trainVarsName = [] + for node in optimized_graph_def.node: + if node.op == "VariableV2": + trainVarsName.append(node.name) + trainVars = list( + map( + lambda x: tf.get_default_graph().get_operation_by_name(x).outputs[0], + trainVarsName, + ) + ) + if args.savePreTrainedWeightsInt: + DumpTFMtData.dumpTrainedWeightsInt( + sess, + trainVars, + "model_weights_scale_{}.inp".format(args.scalingFac), + args.scalingFac, + "w", + ) + if args.savePreTrainedWeightsFloat: + DumpTFMtData.dumpTrainedWeightsFloat( + sess, trainVars, "model_weights_float.inp", "w" + ) + if args.saveImgAndWtData: + DumpTFMtData.dumpImgAndWeightsDataSeparate( + sess, + images[0], + trainVars, + "model_input_scale_{}.inp".format(args.scalingFac), + "model_weights_scale_{}.inp".format(args.scalingFac), + args.scalingFac, + ) diff --git a/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocess_main.py b/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocess_main.py index 211084a2..fe83c176 100644 --- a/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocess_main.py +++ b/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocess_main.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -22,131 +22,170 @@ Part of the code in this file is taken from https://github.com/pudae/tensorflow-densenet. -''' +""" import os, sys, numpy import _pickle as pickle import tensorflow as tf import DenseNet_preprocessing + class ImageCoder(object): - """Helper class that provides TensorFlow image coding utilities.""" + """Helper class that provides TensorFlow image coding utilities.""" + + def __init__(self): + # Create a single Session to run all image coding calls. + self._sess = tf.Session() - def __init__(self): - # Create a single Session to run all image coding calls. - self._sess = tf.Session() + # Initializes function that converts PNG to JPEG data. + self._png_data = tf.placeholder(dtype=tf.string) + image = tf.image.decode_png(self._png_data, channels=3) + self._png_to_jpeg = tf.image.encode_jpeg(image, format="rgb", quality=100) - # Initializes function that converts PNG to JPEG data. - self._png_data = tf.placeholder(dtype=tf.string) - image = tf.image.decode_png(self._png_data, channels=3) - self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) + # Initializes function that converts CMYK JPEG data to RGB JPEG data. + self._cmyk_data = tf.placeholder(dtype=tf.string) + image = tf.image.decode_jpeg(self._cmyk_data, channels=0) + self._cmyk_to_rgb = tf.image.encode_jpeg(image, format="rgb", quality=100) - # Initializes function that converts CMYK JPEG data to RGB JPEG data. - self._cmyk_data = tf.placeholder(dtype=tf.string) - image = tf.image.decode_jpeg(self._cmyk_data, channels=0) - self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100) + # Initializes function that decodes RGB JPEG data. + self._decode_jpeg_data = tf.placeholder(dtype=tf.string) + self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) - # Initializes function that decodes RGB JPEG data. - self._decode_jpeg_data = tf.placeholder(dtype=tf.string) - self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) + def png_to_jpeg(self, image_data): + return self._sess.run(self._png_to_jpeg, feed_dict={self._png_data: image_data}) - def png_to_jpeg(self, image_data): - return self._sess.run(self._png_to_jpeg, - feed_dict={self._png_data: image_data}) + def cmyk_to_rgb(self, image_data): + return self._sess.run( + self._cmyk_to_rgb, feed_dict={self._cmyk_data: image_data} + ) - def cmyk_to_rgb(self, image_data): - return self._sess.run(self._cmyk_to_rgb, - feed_dict={self._cmyk_data: image_data}) + def decode_jpeg(self, image_data): + image = self._sess.run( + self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data} + ) + assert len(image.shape) == 3 + assert image.shape[2] == 3 + return image - def decode_jpeg(self, image_data): - image = self._sess.run(self._decode_jpeg, - feed_dict={self._decode_jpeg_data: image_data}) + +def _process_image(filename, coder): + """Process a single image file. + + Args: + filename: string, path to an image file e.g., '/path/to/example.JPG'. + coder: instance of ImageCoder to provide TensorFlow image coding utils. + Returns: + image_buffer: string, JPEG encoding of RGB image. + height: integer, image height in pixels. + width: integer, image width in pixels. + """ + # Read the image file. + with tf.gfile.GFile(filename, "rb") as f: + image_data = f.read() + + # Decode the RGB JPEG. + image = coder.decode_jpeg(image_data) + + # Check that image converted to RGB assert len(image.shape) == 3 + height = image.shape[0] + width = image.shape[1] assert image.shape[2] == 3 - return image -def _process_image(filename, coder): - """Process a single image file. - - Args: - filename: string, path to an image file e.g., '/path/to/example.JPG'. - coder: instance of ImageCoder to provide TensorFlow image coding utils. - Returns: - image_buffer: string, JPEG encoding of RGB image. - height: integer, image height in pixels. - width: integer, image width in pixels. - """ - # Read the image file. - with tf.gfile.GFile(filename, 'rb') as f: - image_data = f.read() - - # Decode the RGB JPEG. - image = coder.decode_jpeg(image_data) - - # Check that image converted to RGB - assert len(image.shape) == 3 - height = image.shape[0] - width = image.shape[1] - assert image.shape[2] == 3 - - return image, height, width + return image, height, width + def dumpImageDataFloat(imgData, filename, writeMode): - with open(filename, writeMode) as ff: - for xx in numpy.nditer(imgData, order='C'): - ff.write(str(xx) + ' ') - ff.write('\n\n') + with open(filename, writeMode) as ff: + for xx in numpy.nditer(imgData, order="C"): + ff.write(str(xx) + " ") + ff.write("\n\n") + def main(): - if not((len(sys.argv) >= 7) and (len(sys.argv) <= 8)): - print("Args : ?", file=sys.stderr) - exit(1) - - imgFolderName = sys.argv[1] - bboxFolderName = sys.argv[2] - fileNamePrefix = sys.argv[3] - preProcessedImgFolderName = sys.argv[4] - firstImgNum = int(sys.argv[5]) - lastImgNum = int(sys.argv[6]) - randomSubsetIdxFile = None - if (len(sys.argv) == 8): - randomSubsetIdxFile = sys.argv[7] - - randomIdxToBeChosen = None - if randomSubsetIdxFile: - with open(randomSubsetIdxFile, 'r') as ff: - randomIdxToBeChosen = ff.readlines() - randomIdxToBeChosen = list(map(lambda x : int(x.rstrip()), randomIdxToBeChosen)) - assert(lastImgNum <= len(randomIdxToBeChosen)+1) #Assert that the last img num passed is within bounds - - def helper_img(curImgNum): - actualImgNum = curImgNum if not(randomIdxToBeChosen) else randomIdxToBeChosen[curImgNum-1] - saveFilePath = os.path.join(preProcessedImgFolderName, 'ImageNum_' + str(actualImgNum) + '.inp') - if (os.path.exists(saveFilePath)): - print("Preprocessed file already exists. Skipping. Img Num = {0}".format(actualImgNum)) - return - imgFileName = os.path.join(imgFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + '.JPEG') - image_buffer, height, width = _process_image(imgFileName, coder) - preprocessed_image_buffer = sess.run(DenseNet_preprocessing.preprocess_image(image_buffer, 224, 224)) - - dumpImageDataFloat(preprocessed_image_buffer, saveFilePath, 'w') - - pid = os.getpid() - print("PID={0}, firstImgNum={1}, lastImgNum={2}: Starting with first image.".format(pid, firstImgNum, lastImgNum)) - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - coder = ImageCoder() - for curImgNum in range(firstImgNum, (firstImgNum + lastImgNum)//2): - helper_img(curImgNum) - - print("PID={0}, firstImgNum={1}, lastImgNum={2}: Crossed half mark.".format(pid, firstImgNum, lastImgNum)) - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - coder = ImageCoder() - for curImgNum in range((firstImgNum + lastImgNum)//2, lastImgNum): - helper_img(curImgNum) - - print("PID={0}, firstImgNum={1}, lastImgNum={2}: All images done.".format(pid, firstImgNum, lastImgNum)) - -if __name__=='__main__': - main() + if not ((len(sys.argv) >= 7) and (len(sys.argv) <= 8)): + print( + "Args : ?", + file=sys.stderr, + ) + exit(1) + + imgFolderName = sys.argv[1] + bboxFolderName = sys.argv[2] + fileNamePrefix = sys.argv[3] + preProcessedImgFolderName = sys.argv[4] + firstImgNum = int(sys.argv[5]) + lastImgNum = int(sys.argv[6]) + randomSubsetIdxFile = None + if len(sys.argv) == 8: + randomSubsetIdxFile = sys.argv[7] + + randomIdxToBeChosen = None + if randomSubsetIdxFile: + with open(randomSubsetIdxFile, "r") as ff: + randomIdxToBeChosen = ff.readlines() + randomIdxToBeChosen = list( + map(lambda x: int(x.rstrip()), randomIdxToBeChosen) + ) + assert ( + lastImgNum <= len(randomIdxToBeChosen) + 1 + ) # Assert that the last img num passed is within bounds + + def helper_img(curImgNum): + actualImgNum = ( + curImgNum + if not (randomIdxToBeChosen) + else randomIdxToBeChosen[curImgNum - 1] + ) + saveFilePath = os.path.join( + preProcessedImgFolderName, "ImageNum_" + str(actualImgNum) + ".inp" + ) + if os.path.exists(saveFilePath): + print( + "Preprocessed file already exists. Skipping. Img Num = {0}".format( + actualImgNum + ) + ) + return + imgFileName = os.path.join( + imgFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + ".JPEG" + ) + image_buffer, height, width = _process_image(imgFileName, coder) + preprocessed_image_buffer = sess.run( + DenseNet_preprocessing.preprocess_image(image_buffer, 224, 224) + ) + + dumpImageDataFloat(preprocessed_image_buffer, saveFilePath, "w") + + pid = os.getpid() + print( + "PID={0}, firstImgNum={1}, lastImgNum={2}: Starting with first image.".format( + pid, firstImgNum, lastImgNum + ) + ) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + coder = ImageCoder() + for curImgNum in range(firstImgNum, (firstImgNum + lastImgNum) // 2): + helper_img(curImgNum) + + print( + "PID={0}, firstImgNum={1}, lastImgNum={2}: Crossed half mark.".format( + pid, firstImgNum, lastImgNum + ) + ) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + coder = ImageCoder() + for curImgNum in range((firstImgNum + lastImgNum) // 2, lastImgNum): + helper_img(curImgNum) + + print( + "PID={0}, firstImgNum={1}, lastImgNum={2}: All images done.".format( + pid, firstImgNum, lastImgNum + ) + ) + + +if __name__ == "__main__": + main() diff --git a/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocessing.py b/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocessing.py index 5b3ff8c8..7598c726 100644 --- a/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocessing.py +++ b/Athos/Networks/DenseNet/PreProcessingImages/DenseNet_preprocessing.py @@ -47,325 +47,356 @@ def _crop(image, offset_height, offset_width, crop_height, crop_width): - """Crops the given image using the provided offsets and sizes. - - Note that the method doesn't assume we know the input image size but it does - assume we know the input image rank. - - Args: - image: an image of shape [height, width, channels]. - offset_height: a scalar tensor indicating the height offset. - offset_width: a scalar tensor indicating the width offset. - crop_height: the height of the cropped image. - crop_width: the width of the cropped image. - - Returns: - the cropped (and resized) image. - - Raises: - InvalidArgumentError: if the rank is not 3 or if the image dimensions are - less than the crop size. - """ - original_shape = tf.shape(image) - - rank_assertion = tf.Assert( - tf.equal(tf.rank(image), 3), - ['Rank of image must be equal to 3.']) - with tf.control_dependencies([rank_assertion]): - cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]]) - - size_assertion = tf.Assert( - tf.logical_and( - tf.greater_equal(original_shape[0], crop_height), - tf.greater_equal(original_shape[1], crop_width)), - ['Crop size greater than the image size.']) - - offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0])) - - # Use tf.slice instead of crop_to_bounding box as it accepts tensors to - # define the crop size. - with tf.control_dependencies([size_assertion]): - image = tf.slice(image, offsets, cropped_shape) - return tf.reshape(image, cropped_shape) + """Crops the given image using the provided offsets and sizes. + + Note that the method doesn't assume we know the input image size but it does + assume we know the input image rank. + + Args: + image: an image of shape [height, width, channels]. + offset_height: a scalar tensor indicating the height offset. + offset_width: a scalar tensor indicating the width offset. + crop_height: the height of the cropped image. + crop_width: the width of the cropped image. + + Returns: + the cropped (and resized) image. + + Raises: + InvalidArgumentError: if the rank is not 3 or if the image dimensions are + less than the crop size. + """ + original_shape = tf.shape(image) + + rank_assertion = tf.Assert( + tf.equal(tf.rank(image), 3), ["Rank of image must be equal to 3."] + ) + with tf.control_dependencies([rank_assertion]): + cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]]) + + size_assertion = tf.Assert( + tf.logical_and( + tf.greater_equal(original_shape[0], crop_height), + tf.greater_equal(original_shape[1], crop_width), + ), + ["Crop size greater than the image size."], + ) + + offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0])) + + # Use tf.slice instead of crop_to_bounding box as it accepts tensors to + # define the crop size. + with tf.control_dependencies([size_assertion]): + image = tf.slice(image, offsets, cropped_shape) + return tf.reshape(image, cropped_shape) def _random_crop(image_list, crop_height, crop_width): - """Crops the given list of images. - - The function applies the same crop to each image in the list. This can be - effectively applied when there are multiple image inputs of the same - dimension such as: - - image, depths, normals = _random_crop([image, depths, normals], 120, 150) - - Args: - image_list: a list of image tensors of the same dimension but possibly - varying channel. - crop_height: the new height. - crop_width: the new width. - - Returns: - the image_list with cropped images. - - Raises: - ValueError: if there are multiple image inputs provided with different size - or the images are smaller than the crop dimensions. - """ - if not image_list: - raise ValueError('Empty image_list.') - - # Compute the rank assertions. - rank_assertions = [] - for i in range(len(image_list)): - image_rank = tf.rank(image_list[i]) - rank_assert = tf.Assert( - tf.equal(image_rank, 3), - ['Wrong rank for tensor %s [expected] [actual]', - image_list[i].name, 3, image_rank]) - rank_assertions.append(rank_assert) - - with tf.control_dependencies([rank_assertions[0]]): - image_shape = tf.shape(image_list[0]) - image_height = image_shape[0] - image_width = image_shape[1] - crop_size_assert = tf.Assert( - tf.logical_and( - tf.greater_equal(image_height, crop_height), - tf.greater_equal(image_width, crop_width)), - ['Crop size greater than the image size.']) - - asserts = [rank_assertions[0], crop_size_assert] - - for i in range(1, len(image_list)): - image = image_list[i] - asserts.append(rank_assertions[i]) - with tf.control_dependencies([rank_assertions[i]]): - shape = tf.shape(image) - height = shape[0] - width = shape[1] - - height_assert = tf.Assert( - tf.equal(height, image_height), - ['Wrong height for tensor %s [expected][actual]', - image.name, height, image_height]) - width_assert = tf.Assert( - tf.equal(width, image_width), - ['Wrong width for tensor %s [expected][actual]', - image.name, width, image_width]) - asserts.extend([height_assert, width_assert]) - - # Create a random bounding box. - # - # Use tf.random_uniform and not numpy.random.rand as doing the former would - # generate random numbers at graph eval time, unlike the latter which - # generates random numbers at graph definition time. - with tf.control_dependencies(asserts): - max_offset_height = tf.reshape(image_height - crop_height + 1, []) - with tf.control_dependencies(asserts): - max_offset_width = tf.reshape(image_width - crop_width + 1, []) - offset_height = tf.random_uniform( - [], maxval=max_offset_height, dtype=tf.int32) - offset_width = tf.random_uniform( - [], maxval=max_offset_width, dtype=tf.int32) - - return [_crop(image, offset_height, offset_width, - crop_height, crop_width) for image in image_list] + """Crops the given list of images. + + The function applies the same crop to each image in the list. This can be + effectively applied when there are multiple image inputs of the same + dimension such as: + + image, depths, normals = _random_crop([image, depths, normals], 120, 150) + + Args: + image_list: a list of image tensors of the same dimension but possibly + varying channel. + crop_height: the new height. + crop_width: the new width. + + Returns: + the image_list with cropped images. + + Raises: + ValueError: if there are multiple image inputs provided with different size + or the images are smaller than the crop dimensions. + """ + if not image_list: + raise ValueError("Empty image_list.") + + # Compute the rank assertions. + rank_assertions = [] + for i in range(len(image_list)): + image_rank = tf.rank(image_list[i]) + rank_assert = tf.Assert( + tf.equal(image_rank, 3), + [ + "Wrong rank for tensor %s [expected] [actual]", + image_list[i].name, + 3, + image_rank, + ], + ) + rank_assertions.append(rank_assert) + + with tf.control_dependencies([rank_assertions[0]]): + image_shape = tf.shape(image_list[0]) + image_height = image_shape[0] + image_width = image_shape[1] + crop_size_assert = tf.Assert( + tf.logical_and( + tf.greater_equal(image_height, crop_height), + tf.greater_equal(image_width, crop_width), + ), + ["Crop size greater than the image size."], + ) + + asserts = [rank_assertions[0], crop_size_assert] + + for i in range(1, len(image_list)): + image = image_list[i] + asserts.append(rank_assertions[i]) + with tf.control_dependencies([rank_assertions[i]]): + shape = tf.shape(image) + height = shape[0] + width = shape[1] + + height_assert = tf.Assert( + tf.equal(height, image_height), + [ + "Wrong height for tensor %s [expected][actual]", + image.name, + height, + image_height, + ], + ) + width_assert = tf.Assert( + tf.equal(width, image_width), + [ + "Wrong width for tensor %s [expected][actual]", + image.name, + width, + image_width, + ], + ) + asserts.extend([height_assert, width_assert]) + + # Create a random bounding box. + # + # Use tf.random_uniform and not numpy.random.rand as doing the former would + # generate random numbers at graph eval time, unlike the latter which + # generates random numbers at graph definition time. + with tf.control_dependencies(asserts): + max_offset_height = tf.reshape(image_height - crop_height + 1, []) + with tf.control_dependencies(asserts): + max_offset_width = tf.reshape(image_width - crop_width + 1, []) + offset_height = tf.random_uniform([], maxval=max_offset_height, dtype=tf.int32) + offset_width = tf.random_uniform([], maxval=max_offset_width, dtype=tf.int32) + + return [ + _crop(image, offset_height, offset_width, crop_height, crop_width) + for image in image_list + ] def _central_crop(image_list, crop_height, crop_width): - """Performs central crops of the given image list. + """Performs central crops of the given image list. - Args: - image_list: a list of image tensors of the same dimension but possibly - varying channel. - crop_height: the height of the image following the crop. - crop_width: the width of the image following the crop. + Args: + image_list: a list of image tensors of the same dimension but possibly + varying channel. + crop_height: the height of the image following the crop. + crop_width: the width of the image following the crop. - Returns: - the list of cropped images. - """ - outputs = [] - for image in image_list: - image_height = tf.shape(image)[0] - image_width = tf.shape(image)[1] + Returns: + the list of cropped images. + """ + outputs = [] + for image in image_list: + image_height = tf.shape(image)[0] + image_width = tf.shape(image)[1] - offset_height = (image_height - crop_height) / 2 - offset_width = (image_width - crop_width) / 2 + offset_height = (image_height - crop_height) / 2 + offset_width = (image_width - crop_width) / 2 - outputs.append(_crop(image, offset_height, offset_width, - crop_height, crop_width)) - return outputs + outputs.append( + _crop(image, offset_height, offset_width, crop_height, crop_width) + ) + return outputs def _mean_image_subtraction(image, means): - """Subtracts the given means from each image channel. + """Subtracts the given means from each image channel. - For example: - means = [123.68, 116.779, 103.939] - image = _mean_image_subtraction(image, means) + For example: + means = [123.68, 116.779, 103.939] + image = _mean_image_subtraction(image, means) - Note that the rank of `image` must be known. + Note that the rank of `image` must be known. - Args: - image: a tensor of size [height, width, C]. - means: a C-vector of values to subtract from each channel. + Args: + image: a tensor of size [height, width, C]. + means: a C-vector of values to subtract from each channel. - Returns: - the centered image. + Returns: + the centered image. - Raises: - ValueError: If the rank of `image` is unknown, if `image` has a rank other - than three or if the number of channels in `image` doesn't match the - number of values in `means`. - """ - if image.get_shape().ndims != 3: - raise ValueError('Input must be of size [height, width, C>0]') - num_channels = image.get_shape().as_list()[-1] - if len(means) != num_channels: - raise ValueError('len(means) must match the number of channels') + Raises: + ValueError: If the rank of `image` is unknown, if `image` has a rank other + than three or if the number of channels in `image` doesn't match the + number of values in `means`. + """ + if image.get_shape().ndims != 3: + raise ValueError("Input must be of size [height, width, C>0]") + num_channels = image.get_shape().as_list()[-1] + if len(means) != num_channels: + raise ValueError("len(means) must match the number of channels") - channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image) - for i in range(num_channels): - channels[i] -= means[i] - return tf.concat(axis=2, values=channels) + channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image) + for i in range(num_channels): + channels[i] -= means[i] + return tf.concat(axis=2, values=channels) def _smallest_size_at_least(height, width, smallest_side): - """Computes new shape with the smallest side equal to `smallest_side`. + """Computes new shape with the smallest side equal to `smallest_side`. - Computes new shape with the smallest side equal to `smallest_side` while - preserving the original aspect ratio. + Computes new shape with the smallest side equal to `smallest_side` while + preserving the original aspect ratio. - Args: - height: an int32 scalar tensor indicating the current height. - width: an int32 scalar tensor indicating the current width. - smallest_side: A python integer or scalar `Tensor` indicating the size of - the smallest side after resize. + Args: + height: an int32 scalar tensor indicating the current height. + width: an int32 scalar tensor indicating the current width. + smallest_side: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. - Returns: - new_height: an int32 scalar tensor indicating the new height. - new_width: and int32 scalar tensor indicating the new width. - """ - smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) + Returns: + new_height: an int32 scalar tensor indicating the new height. + new_width: and int32 scalar tensor indicating the new width. + """ + smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) - height = tf.to_float(height) - width = tf.to_float(width) - smallest_side = tf.to_float(smallest_side) + height = tf.to_float(height) + width = tf.to_float(width) + smallest_side = tf.to_float(smallest_side) - scale = tf.cond(tf.greater(height, width), - lambda: smallest_side / width, - lambda: smallest_side / height) - new_height = tf.to_int32(height * scale) - new_width = tf.to_int32(width * scale) - return new_height, new_width + scale = tf.cond( + tf.greater(height, width), + lambda: smallest_side / width, + lambda: smallest_side / height, + ) + new_height = tf.to_int32(height * scale) + new_width = tf.to_int32(width * scale) + return new_height, new_width def _aspect_preserving_resize(image, smallest_side): - """Resize images preserving the original aspect ratio. - - Args: - image: A 3-D image `Tensor`. - smallest_side: A python integer or scalar `Tensor` indicating the size of - the smallest side after resize. - - Returns: - resized_image: A 3-D tensor containing the resized image. - """ - smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) - - shape = tf.shape(image) - height = shape[0] - width = shape[1] - new_height, new_width = _smallest_size_at_least(height, width, smallest_side) - image = tf.expand_dims(image, 0) - resized_image = tf.image.resize_bilinear(image, [new_height, new_width], - align_corners=False) - resized_image = tf.squeeze(resized_image) - resized_image.set_shape([None, None, 3]) - return resized_image - - -def preprocess_for_train(image, - output_height, - output_width, - resize_side_min=_RESIZE_SIDE_MIN, - resize_side_max=_RESIZE_SIDE_MAX): - """Preprocesses the given image for training. - - Note that the actual resizing scale is sampled from - [`resize_size_min`, `resize_size_max`]. - - Args: - image: A `Tensor` representing an image of arbitrary size. - output_height: The height of the image after preprocessing. - output_width: The width of the image after preprocessing. - resize_side_min: The lower bound for the smallest side of the image for - aspect-preserving resizing. - resize_side_max: The upper bound for the smallest side of the image for - aspect-preserving resizing. - - Returns: - A preprocessed image. - """ - resize_side = tf.random_uniform( - [], minval=resize_side_min, maxval=resize_side_max+1, dtype=tf.int32) - - image = _aspect_preserving_resize(image, resize_side) - image = _random_crop([image], output_height, output_width)[0] - image.set_shape([output_height, output_width, 3]) - image = tf.to_float(image) - image = tf.image.random_flip_left_right(image) - - image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]) - return image * _SCALE_FACTOR + """Resize images preserving the original aspect ratio. + + Args: + image: A 3-D image `Tensor`. + smallest_side: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. + + Returns: + resized_image: A 3-D tensor containing the resized image. + """ + smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) + + shape = tf.shape(image) + height = shape[0] + width = shape[1] + new_height, new_width = _smallest_size_at_least(height, width, smallest_side) + image = tf.expand_dims(image, 0) + resized_image = tf.image.resize_bilinear( + image, [new_height, new_width], align_corners=False + ) + resized_image = tf.squeeze(resized_image) + resized_image.set_shape([None, None, 3]) + return resized_image + + +def preprocess_for_train( + image, + output_height, + output_width, + resize_side_min=_RESIZE_SIDE_MIN, + resize_side_max=_RESIZE_SIDE_MAX, +): + """Preprocesses the given image for training. + + Note that the actual resizing scale is sampled from + [`resize_size_min`, `resize_size_max`]. + + Args: + image: A `Tensor` representing an image of arbitrary size. + output_height: The height of the image after preprocessing. + output_width: The width of the image after preprocessing. + resize_side_min: The lower bound for the smallest side of the image for + aspect-preserving resizing. + resize_side_max: The upper bound for the smallest side of the image for + aspect-preserving resizing. + + Returns: + A preprocessed image. + """ + resize_side = tf.random_uniform( + [], minval=resize_side_min, maxval=resize_side_max + 1, dtype=tf.int32 + ) + + image = _aspect_preserving_resize(image, resize_side) + image = _random_crop([image], output_height, output_width)[0] + image.set_shape([output_height, output_width, 3]) + image = tf.to_float(image) + image = tf.image.random_flip_left_right(image) + + image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]) + return image * _SCALE_FACTOR def preprocess_for_eval(image, output_height, output_width, resize_side): - """Preprocesses the given image for evaluation. - - Args: - image: A `Tensor` representing an image of arbitrary size. - output_height: The height of the image after preprocessing. - output_width: The width of the image after preprocessing. - resize_side: The smallest side of the image for aspect-preserving resizing. - - Returns: - A preprocessed image. - """ - image = _aspect_preserving_resize(image, resize_side) - image = _central_crop([image], output_height, output_width)[0] - image.set_shape([output_height, output_width, 3]) - image = tf.to_float(image) - - image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]) - return image * _SCALE_FACTOR - - -def preprocess_image(image, output_height, output_width, is_training=False, - resize_side_min=_RESIZE_SIDE_MIN, - resize_side_max=_RESIZE_SIDE_MAX): - """Preprocesses the given image. - - Args: - image: A `Tensor` representing an image of arbitrary size. - output_height: The height of the image after preprocessing. - output_width: The width of the image after preprocessing. - is_training: `True` if we're preprocessing the image for training and - `False` otherwise. - resize_side_min: The lower bound for the smallest side of the image for - aspect-preserving resizing. If `is_training` is `False`, then this value - is used for rescaling. - resize_side_max: The upper bound for the smallest side of the image for - aspect-preserving resizing. If `is_training` is `False`, this value is - ignored. Otherwise, the resize side is sampled from - [resize_size_min, resize_size_max]. - - Returns: - A preprocessed image. - """ - if is_training: - return preprocess_for_train(image, output_height, output_width, - resize_side_min, resize_side_max) - else: - return preprocess_for_eval(image, output_height, output_width, - resize_side_min) + """Preprocesses the given image for evaluation. + + Args: + image: A `Tensor` representing an image of arbitrary size. + output_height: The height of the image after preprocessing. + output_width: The width of the image after preprocessing. + resize_side: The smallest side of the image for aspect-preserving resizing. + + Returns: + A preprocessed image. + """ + image = _aspect_preserving_resize(image, resize_side) + image = _central_crop([image], output_height, output_width)[0] + image.set_shape([output_height, output_width, 3]) + image = tf.to_float(image) + + image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]) + return image * _SCALE_FACTOR + + +def preprocess_image( + image, + output_height, + output_width, + is_training=False, + resize_side_min=_RESIZE_SIDE_MIN, + resize_side_max=_RESIZE_SIDE_MAX, +): + """Preprocesses the given image. + + Args: + image: A `Tensor` representing an image of arbitrary size. + output_height: The height of the image after preprocessing. + output_width: The width of the image after preprocessing. + is_training: `True` if we're preprocessing the image for training and + `False` otherwise. + resize_side_min: The lower bound for the smallest side of the image for + aspect-preserving resizing. If `is_training` is `False`, then this value + is used for rescaling. + resize_side_max: The upper bound for the smallest side of the image for + aspect-preserving resizing. If `is_training` is `False`, this value is + ignored. Otherwise, the resize side is sampled from + [resize_size_min, resize_size_max]. + + Returns: + A preprocessed image. + """ + if is_training: + return preprocess_for_train( + image, output_height, output_width, resize_side_min, resize_side_max + ) + else: + return preprocess_for_eval(image, output_height, output_width, resize_side_min) diff --git a/Athos/Networks/DenseNet/densenet.py b/Athos/Networks/DenseNet/densenet.py index c1cef813..52f75e2c 100644 --- a/Athos/Networks/DenseNet/densenet.py +++ b/Athos/Networks/DenseNet/densenet.py @@ -29,214 +29,283 @@ @slim.add_arg_scope -def _global_avg_pool2d(inputs, data_format='NHWC', scope=None, outputs_collections=None): - with tf.variable_scope(scope, 'xx', [inputs]) as sc: - axis = [1, 2] if data_format == 'NHWC' else [2, 3] - # net = tf.reduce_mean(inputs, axis=axis, keep_dims=True) - net = tf.nn.avg_pool(inputs, ksize=[1,7,7,1], strides=[1,1,1,1], padding='VALID') - net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) - return net +def _global_avg_pool2d( + inputs, data_format="NHWC", scope=None, outputs_collections=None +): + with tf.variable_scope(scope, "xx", [inputs]) as sc: + axis = [1, 2] if data_format == "NHWC" else [2, 3] + # net = tf.reduce_mean(inputs, axis=axis, keep_dims=True) + net = tf.nn.avg_pool( + inputs, ksize=[1, 7, 7, 1], strides=[1, 1, 1, 1], padding="VALID" + ) + net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) + return net @slim.add_arg_scope -def _conv(inputs, num_filters, kernel_size, stride=1, dropout_rate=None, - scope=None, outputs_collections=None): - with tf.variable_scope(scope, 'xx', [inputs]) as sc: - net = slim.batch_norm(inputs) - net = tf.nn.relu(net) - net = slim.conv2d(net, num_filters, kernel_size) +def _conv( + inputs, + num_filters, + kernel_size, + stride=1, + dropout_rate=None, + scope=None, + outputs_collections=None, +): + with tf.variable_scope(scope, "xx", [inputs]) as sc: + net = slim.batch_norm(inputs) + net = tf.nn.relu(net) + net = slim.conv2d(net, num_filters, kernel_size) - if dropout_rate: - net = tf.nn.dropout(net) + if dropout_rate: + net = tf.nn.dropout(net) - net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) + net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) - return net + return net @slim.add_arg_scope -def _conv_block(inputs, num_filters, data_format='NHWC', scope=None, outputs_collections=None): - with tf.variable_scope(scope, 'conv_blockx', [inputs]) as sc: - net = inputs - net = _conv(net, num_filters*4, 1, scope='x1') - net = _conv(net, num_filters, 3, scope='x2') - if data_format == 'NHWC': - net = tf.concat([inputs, net], axis=3) - else: # "NCHW" - net = tf.concat([inputs, net], axis=1) +def _conv_block( + inputs, num_filters, data_format="NHWC", scope=None, outputs_collections=None +): + with tf.variable_scope(scope, "conv_blockx", [inputs]) as sc: + net = inputs + net = _conv(net, num_filters * 4, 1, scope="x1") + net = _conv(net, num_filters, 3, scope="x2") + if data_format == "NHWC": + net = tf.concat([inputs, net], axis=3) + else: # "NCHW" + net = tf.concat([inputs, net], axis=1) + + net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) - net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) - - return net + return net @slim.add_arg_scope -def _dense_block(inputs, num_layers, num_filters, growth_rate, - grow_num_filters=True, scope=None, outputs_collections=None): +def _dense_block( + inputs, + num_layers, + num_filters, + growth_rate, + grow_num_filters=True, + scope=None, + outputs_collections=None, +): - with tf.variable_scope(scope, 'dense_blockx', [inputs]) as sc: - net = inputs - for i in range(num_layers): - branch = i + 1 - net = _conv_block(net, growth_rate, scope='conv_block'+str(branch)) + with tf.variable_scope(scope, "dense_blockx", [inputs]) as sc: + net = inputs + for i in range(num_layers): + branch = i + 1 + net = _conv_block(net, growth_rate, scope="conv_block" + str(branch)) - if grow_num_filters: - num_filters += growth_rate + if grow_num_filters: + num_filters += growth_rate - net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) + net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) - return net, num_filters + return net, num_filters @slim.add_arg_scope -def _transition_block(inputs, num_filters, compression=1.0, - scope=None, outputs_collections=None): - - num_filters = int(num_filters * compression) - with tf.variable_scope(scope, 'transition_blockx', [inputs]) as sc: - net = inputs - net = _conv(net, num_filters, 1, scope='blk') - - net = slim.avg_pool2d(net, 2) - - net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) - - return net, num_filters - - -def densenet(inputs, - num_classes=1000, - reduction=None, - growth_rate=None, - num_filters=None, - num_layers=None, - dropout_rate=None, - data_format='NHWC', - is_training=True, - reuse=None, - scope=None): - assert reduction is not None - assert growth_rate is not None - assert num_filters is not None - assert num_layers is not None - - compression = 1.0 - reduction - num_dense_blocks = len(num_layers) - - if data_format == 'NCHW': - inputs = tf.transpose(inputs, [0, 3, 1, 2]) - - with tf.variable_scope(scope, 'densenetxxx', [inputs, num_classes], - reuse=reuse) as sc: - end_points_collection = sc.name + '_end_points' - with slim.arg_scope([slim.batch_norm, slim.dropout], - is_training=is_training), \ - slim.arg_scope([slim.conv2d, _conv, _conv_block, - _dense_block, _transition_block], - outputs_collections=end_points_collection), \ - slim.arg_scope([_conv], dropout_rate=dropout_rate): - net = inputs - - # initial convolution - net = slim.conv2d(net, num_filters, 7, stride=2, scope='conv1') - net = slim.batch_norm(net) - net = tf.nn.relu(net) - net = slim.max_pool2d(net, 3, stride=2, padding='SAME') - - # blocks - for i in range(num_dense_blocks - 1): - # dense blocks - net, num_filters = _dense_block(net, num_layers[i], num_filters, - growth_rate, - scope='dense_block' + str(i+1)) - - # Add transition_block - net, num_filters = _transition_block(net, num_filters, - compression=compression, - scope='transition_block' + str(i+1)) - - net, num_filters = _dense_block( - net, num_layers[-1], num_filters, - growth_rate, - scope='dense_block' + str(num_dense_blocks)) - - # final blocks - with tf.variable_scope('final_block', [inputs]): - net = slim.batch_norm(net) - net = tf.nn.relu(net) - net = _global_avg_pool2d(net, scope='global_avg_pool') +def _transition_block( + inputs, num_filters, compression=1.0, scope=None, outputs_collections=None +): + + num_filters = int(num_filters * compression) + with tf.variable_scope(scope, "transition_blockx", [inputs]) as sc: + net = inputs + net = _conv(net, num_filters, 1, scope="blk") + + net = slim.avg_pool2d(net, 2) + + net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) + + return net, num_filters + + +def densenet( + inputs, + num_classes=1000, + reduction=None, + growth_rate=None, + num_filters=None, + num_layers=None, + dropout_rate=None, + data_format="NHWC", + is_training=True, + reuse=None, + scope=None, +): + assert reduction is not None + assert growth_rate is not None + assert num_filters is not None + assert num_layers is not None + + compression = 1.0 - reduction + num_dense_blocks = len(num_layers) + + if data_format == "NCHW": + inputs = tf.transpose(inputs, [0, 3, 1, 2]) + + with tf.variable_scope( + scope, "densenetxxx", [inputs, num_classes], reuse=reuse + ) as sc: + end_points_collection = sc.name + "_end_points" + with slim.arg_scope( + [slim.batch_norm, slim.dropout], is_training=is_training + ), slim.arg_scope( + [slim.conv2d, _conv, _conv_block, _dense_block, _transition_block], + outputs_collections=end_points_collection, + ), slim.arg_scope( + [_conv], dropout_rate=dropout_rate + ): + net = inputs + + # initial convolution + net = slim.conv2d(net, num_filters, 7, stride=2, scope="conv1") + net = slim.batch_norm(net) + net = tf.nn.relu(net) + net = slim.max_pool2d(net, 3, stride=2, padding="SAME") + + # blocks + for i in range(num_dense_blocks - 1): + # dense blocks + net, num_filters = _dense_block( + net, + num_layers[i], + num_filters, + growth_rate, + scope="dense_block" + str(i + 1), + ) + + # Add transition_block + net, num_filters = _transition_block( + net, + num_filters, + compression=compression, + scope="transition_block" + str(i + 1), + ) + + net, num_filters = _dense_block( + net, + num_layers[-1], + num_filters, + growth_rate, + scope="dense_block" + str(num_dense_blocks), + ) + + # final blocks + with tf.variable_scope("final_block", [inputs]): + net = slim.batch_norm(net) + net = tf.nn.relu(net) + net = _global_avg_pool2d(net, scope="global_avg_pool") + + net = slim.conv2d( + net, + num_classes, + 1, + biases_initializer=tf.zeros_initializer(), + scope="logits", + ) + + end_points = slim.utils.convert_collection_to_dict(end_points_collection) + + # if num_classes is not None: + # end_points['predictions'] = slim.softmax(net, scope='predictions') + + return net, end_points + + +def densenet121( + inputs, num_classes=1000, data_format="NHWC", is_training=True, reuse=None +): + return densenet( + inputs, + num_classes=num_classes, + reduction=0.5, + growth_rate=32, + num_filters=64, + num_layers=[6, 12, 24, 16], + data_format=data_format, + is_training=is_training, + reuse=reuse, + scope="densenet121", + ) - net = slim.conv2d(net, num_classes, 1, - biases_initializer=tf.zeros_initializer(), - scope='logits') - end_points = slim.utils.convert_collection_to_dict( - end_points_collection) - - # if num_classes is not None: - # end_points['predictions'] = slim.softmax(net, scope='predictions') - - return net, end_points +densenet121.default_image_size = 224 -def densenet121(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): - return densenet(inputs, - num_classes=num_classes, - reduction=0.5, - growth_rate=32, - num_filters=64, - num_layers=[6,12,24,16], - data_format=data_format, - is_training=is_training, - reuse=reuse, - scope='densenet121') -densenet121.default_image_size = 224 +def densenet161( + inputs, num_classes=1000, data_format="NHWC", is_training=True, reuse=None +): + return densenet( + inputs, + num_classes=num_classes, + reduction=0.5, + growth_rate=48, + num_filters=96, + num_layers=[6, 12, 36, 24], + data_format=data_format, + is_training=is_training, + reuse=reuse, + scope="densenet161", + ) -def densenet161(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): - return densenet(inputs, - num_classes=num_classes, - reduction=0.5, - growth_rate=48, - num_filters=96, - num_layers=[6,12,36,24], - data_format=data_format, - is_training=is_training, - reuse=reuse, - scope='densenet161') densenet161.default_image_size = 224 -def densenet169(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): - return densenet(inputs, - num_classes=num_classes, - reduction=0.5, - growth_rate=32, - num_filters=64, - num_layers=[6,12,32,32], - data_format=data_format, - is_training=is_training, - reuse=reuse, - scope='densenet169') -densenet169.default_image_size = 224 +def densenet169( + inputs, num_classes=1000, data_format="NHWC", is_training=True, reuse=None +): + return densenet( + inputs, + num_classes=num_classes, + reduction=0.5, + growth_rate=32, + num_filters=64, + num_layers=[6, 12, 32, 32], + data_format=data_format, + is_training=is_training, + reuse=reuse, + scope="densenet169", + ) -def densenet_arg_scope(weight_decay=1e-4, - batch_norm_decay=0.99, - batch_norm_epsilon=1.1e-5, - data_format='NHWC'): - with slim.arg_scope([slim.conv2d, slim.batch_norm, slim.avg_pool2d, slim.max_pool2d, - _conv_block, _global_avg_pool2d], - data_format=data_format): - with slim.arg_scope([slim.conv2d], - # weights_regularizer=slim.l2_regularizer(weight_decay), - weights_initializer=tf.zeros_initializer(), - activation_fn=None, - biases_initializer=None): - with slim.arg_scope([slim.batch_norm], - scale=True, - decay=batch_norm_decay, - epsilon=batch_norm_epsilon) as scope: - return scope +densenet169.default_image_size = 224 +def densenet_arg_scope( + weight_decay=1e-4, + batch_norm_decay=0.99, + batch_norm_epsilon=1.1e-5, + data_format="NHWC", +): + with slim.arg_scope( + [ + slim.conv2d, + slim.batch_norm, + slim.avg_pool2d, + slim.max_pool2d, + _conv_block, + _global_avg_pool2d, + ], + data_format=data_format, + ): + with slim.arg_scope( + [slim.conv2d], + # weights_regularizer=slim.l2_regularizer(weight_decay), + weights_initializer=tf.zeros_initializer(), + activation_fn=None, + biases_initializer=None, + ): + with slim.arg_scope( + [slim.batch_norm], + scale=True, + decay=batch_norm_decay, + epsilon=batch_norm_epsilon, + ) as scope: + return scope diff --git a/Athos/Networks/DenseNet/nets_factory.py b/Athos/Networks/DenseNet/nets_factory.py index 3b324f86..245e04db 100644 --- a/Athos/Networks/DenseNet/nets_factory.py +++ b/Athos/Networks/DenseNet/nets_factory.py @@ -26,45 +26,50 @@ slim = tf.contrib.slim networks_map = { - 'densenet121': densenet.densenet121, - 'densenet161': densenet.densenet161, - 'densenet169': densenet.densenet169, - } + "densenet121": densenet.densenet121, + "densenet161": densenet.densenet161, + "densenet169": densenet.densenet169, +} arg_scopes_map = { - 'densenet121': densenet.densenet_arg_scope, - 'densenet161': densenet.densenet_arg_scope, - 'densenet169': densenet.densenet_arg_scope, - } + "densenet121": densenet.densenet_arg_scope, + "densenet161": densenet.densenet_arg_scope, + "densenet169": densenet.densenet_arg_scope, +} -def get_network_fn(name, num_classes, weight_decay=0.0, data_format='NHWC', - is_training=False): - """Returns a network_fn such as `logits, end_points = network_fn(images)`. +def get_network_fn( + name, num_classes, weight_decay=0.0, data_format="NHWC", is_training=False +): + """Returns a network_fn such as `logits, end_points = network_fn(images)`. - Args: - name: The name of the network. - num_classes: The number of classes to use for classification. - weight_decay: The l2 coefficient for the model weights. - is_training: `True` if the model is being used for training and `False` - otherwise. + Args: + name: The name of the network. + num_classes: The number of classes to use for classification. + weight_decay: The l2 coefficient for the model weights. + is_training: `True` if the model is being used for training and `False` + otherwise. - Returns: - network_fn: A function that applies the model to a batch of images. It has - the following signature: - logits, end_points = network_fn(images) - Raises: - ValueError: If network `name` is not recognized. - """ - if name not in networks_map: - raise ValueError('Name of network unknown %s' % name) - arg_scope = arg_scopes_map[name](weight_decay=weight_decay, data_format=data_format) - func = networks_map[name] - @functools.wraps(func) - def network_fn(images): - with slim.arg_scope(arg_scope): - return func(images, num_classes, data_format=data_format, is_training=is_training) - if hasattr(func, 'default_image_size'): - network_fn.default_image_size = func.default_image_size + Returns: + network_fn: A function that applies the model to a batch of images. It has + the following signature: + logits, end_points = network_fn(images) + Raises: + ValueError: If network `name` is not recognized. + """ + if name not in networks_map: + raise ValueError("Name of network unknown %s" % name) + arg_scope = arg_scopes_map[name](weight_decay=weight_decay, data_format=data_format) + func = networks_map[name] - return network_fn + @functools.wraps(func) + def network_fn(images): + with slim.arg_scope(arg_scope): + return func( + images, num_classes, data_format=data_format, is_training=is_training + ) + + if hasattr(func, "default_image_size"): + network_fn.default_image_size = func.default_image_size + + return network_fn diff --git a/Athos/Networks/Lenet/lenetLarge_mnist_inference.py b/Athos/Networks/Lenet/lenetLarge_mnist_inference.py index 5e5c77e6..3f0e01cd 100644 --- a/Athos/Networks/Lenet/lenetLarge_mnist_inference.py +++ b/Athos/Networks/Lenet/lenetLarge_mnist_inference.py @@ -36,152 +36,178 @@ import tensorflow as tf import time -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData FLAGS = None def deepnn(x): - """deepnn builds the graph for a deep net for classifying digits. - - Args: - x: an input tensor with the dimensions (N_examples, 784), where 784 is the - number of pixels in a standard MNIST image. - - Returns: - A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values - equal to the logits of classifying the digit into one of 10 classes (the - digits 0-9). keep_prob is a scalar placeholder for the probability of - dropout. - """ - # Reshape to use within a convolutional neural net. - # Last dimension is for "features" - there is only one here, since images are - # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. - with tf.name_scope('reshape'): - x_image = tf.reshape(x, [-1, 28, 28, 1]) - - # First convolutional layer - maps one grayscale image to 32 feature maps. - with tf.name_scope('conv1'): - W_conv1 = weight_variable([5, 5, 1, 32]) - b_conv1 = bias_variable([32]) - h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) - - # Pooling layer - downsamples by 2X. - with tf.name_scope('pool1'): - h_pool1 = max_pool_2x2(h_conv1) - - # Second convolutional layer -- maps 32 feature maps to 64. - with tf.name_scope('conv2'): - W_conv2 = weight_variable([5, 5, 32, 64]) - b_conv2 = bias_variable([64]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) - - # Second pooling layer. - with tf.name_scope('pool2'): - h_pool2 = max_pool_2x2(h_conv2) - - # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image - # is down to 7x7x64 feature maps -- maps this to 1024 features. - with tf.name_scope('fc1'): - W_fc1 = weight_variable([3136, 512]) - b_fc1 = bias_variable([512]) - - h_pool2_flat = tf.reshape(h_pool2, [-1, 3136]) - h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) - - # Dropout - controls the complexity of the model, prevents co-adaptation of - # features. - # with tf.name_scope('dropout'): - keep_prob = tf.placeholder(tf.float32) - # h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) - - # Map the 1024 features to 10 classes, one for each digit - with tf.name_scope('fc2'): - W_fc2 = weight_variable([512, 10]) - b_fc2 = bias_variable([10]) - - y_conv = tf.matmul(h_fc1, W_fc2) + b_fc2 - return y_conv, keep_prob, [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2], [W_conv1] + """deepnn builds the graph for a deep net for classifying digits. + + Args: + x: an input tensor with the dimensions (N_examples, 784), where 784 is the + number of pixels in a standard MNIST image. + + Returns: + A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values + equal to the logits of classifying the digit into one of 10 classes (the + digits 0-9). keep_prob is a scalar placeholder for the probability of + dropout. + """ + # Reshape to use within a convolutional neural net. + # Last dimension is for "features" - there is only one here, since images are + # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. + with tf.name_scope("reshape"): + x_image = tf.reshape(x, [-1, 28, 28, 1]) + + # First convolutional layer - maps one grayscale image to 32 feature maps. + with tf.name_scope("conv1"): + W_conv1 = weight_variable([5, 5, 1, 32]) + b_conv1 = bias_variable([32]) + h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) + + # Pooling layer - downsamples by 2X. + with tf.name_scope("pool1"): + h_pool1 = max_pool_2x2(h_conv1) + + # Second convolutional layer -- maps 32 feature maps to 64. + with tf.name_scope("conv2"): + W_conv2 = weight_variable([5, 5, 32, 64]) + b_conv2 = bias_variable([64]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) + + # Second pooling layer. + with tf.name_scope("pool2"): + h_pool2 = max_pool_2x2(h_conv2) + + # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image + # is down to 7x7x64 feature maps -- maps this to 1024 features. + with tf.name_scope("fc1"): + W_fc1 = weight_variable([3136, 512]) + b_fc1 = bias_variable([512]) + + h_pool2_flat = tf.reshape(h_pool2, [-1, 3136]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) + + # Dropout - controls the complexity of the model, prevents co-adaptation of + # features. + # with tf.name_scope('dropout'): + keep_prob = tf.placeholder(tf.float32) + # h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) + + # Map the 1024 features to 10 classes, one for each digit + with tf.name_scope("fc2"): + W_fc2 = weight_variable([512, 10]) + b_fc2 = bias_variable([10]) + + y_conv = tf.matmul(h_fc1, W_fc2) + b_fc2 + return ( + y_conv, + keep_prob, + [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2], + [W_conv1], + ) + def conv2d(x, W): - """conv2d returns a 2d convolution layer with full stride.""" - return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') + """conv2d returns a 2d convolution layer with full stride.""" + return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME") def max_pool_2x2(x): - """max_pool_2x2 downsamples a feature map by 2X.""" - return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], - strides=[1, 2, 2, 1], padding='SAME') + """max_pool_2x2 downsamples a feature map by 2X.""" + return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME") + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - # initial = tf.truncated_normal(shape, stddev=0.1) - initial = tf.constant(0.25, shape=shape) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + # initial = tf.truncated_normal(shape, stddev=0.1) + initial = tf.constant(0.25, shape=shape) + return tf.Variable(initial) def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.25, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.25, shape=shape) + return tf.Variable(initial) + def findLabel(oneHotAns): - for i in range(10): - if oneHotAns[i] == 1.0: - return i - return -1 + for i in range(10): + if oneHotAns[i] == 1.0: + return i + return -1 + def main(_): - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,seed=1) - x = tf.placeholder(tf.float32, [None, 784]) - y_ = tf.placeholder(tf.float32, [None, 10]) - y_conv, keep_prob, modelWeights, prTensors = deepnn(x) - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - - if len(sys.argv)!=2: - print("Need the mnist image number to run inference on.") - exit(-1) - - curImageNum = int(sys.argv[1]) - imagex = mnist.test.images[curImageNum:curImageNum+1,:] - imagey = mnist.test.labels[curImageNum:curImageNum+1,:] - keep_prob_value = 1.0 - feed_dict = {x:imagex, y_:imagey, keep_prob:keep_prob_value} - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'fc2/add': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - saver = tf.train.Saver(modelWeights) - saver.restore(sess, './TrainedModel/lenetLargeModel') - - start_time = time.time() - prediction = sess.run([y_conv, keep_prob],feed_dict=feed_dict) - duration = time.time() - start_time - - print("Duration of execution : ", duration) - print('Result ::::::::: \n', prediction[0]) - - print("Prediction: ", np.argmax(prediction[0])) - print("Actual label: ", findLabel(imagey[0])) - - trainVarsName = [] - for node in optimized_graph_def.node: - if node.op=="VariableV2": - trainVarsName.append(node.name) - trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, imagex[0], trainVars, 'LenetLarge_img_{0}.inp'.format(curImageNum), 'LenetLarge_weights_{0}.inp'.format(curImageNum), 15) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--data_dir', type=str, - default='/tmp/data', - help='Directory for storing input data') - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) \ No newline at end of file + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True, seed=1) + x = tf.placeholder(tf.float32, [None, 784]) + y_ = tf.placeholder(tf.float32, [None, 10]) + y_conv, keep_prob, modelWeights, prTensors = deepnn(x) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + if len(sys.argv) != 2: + print("Need the mnist image number to run inference on.") + exit(-1) + + curImageNum = int(sys.argv[1]) + imagex = mnist.test.images[curImageNum : curImageNum + 1, :] + imagey = mnist.test.labels[curImageNum : curImageNum + 1, :] + keep_prob_value = 1.0 + feed_dict = {x: imagex, y_: imagey, keep_prob: keep_prob_value} + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "fc2/add": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) + + saver = tf.train.Saver(modelWeights) + saver.restore(sess, "./TrainedModel/lenetLargeModel") + + start_time = time.time() + prediction = sess.run([y_conv, keep_prob], feed_dict=feed_dict) + duration = time.time() - start_time + + print("Duration of execution : ", duration) + print("Result ::::::::: \n", prediction[0]) + + print("Prediction: ", np.argmax(prediction[0])) + print("Actual label: ", findLabel(imagey[0])) + + trainVarsName = [] + for node in optimized_graph_def.node: + if node.op == "VariableV2": + trainVarsName.append(node.name) + trainVars = list( + map( + lambda x: tf.get_default_graph().get_operation_by_name(x).outputs[0], + trainVarsName, + ) + ) + DumpTFMtData.dumpImgAndWeightsDataSeparate( + sess, + imagex[0], + trainVars, + "LenetLarge_img_{0}.inp".format(curImageNum), + "LenetLarge_weights_{0}.inp".format(curImageNum), + 15, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/data", + help="Directory for storing input data", + ) + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/Athos/Networks/Lenet/lenetLarge_mnist_train.py b/Athos/Networks/Lenet/lenetLarge_mnist_train.py index 90af6f5e..6313abd6 100644 --- a/Athos/Networks/Lenet/lenetLarge_mnist_train.py +++ b/Athos/Networks/Lenet/lenetLarge_mnist_train.py @@ -40,144 +40,151 @@ def deepnn(x): - """deepnn builds the graph for a deep net for classifying digits. - - Args: - x: an input tensor with the dimensions (N_examples, 784), where 784 is the - number of pixels in a standard MNIST image. - - Returns: - A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values - equal to the logits of classifying the digit into one of 10 classes (the - digits 0-9). keep_prob is a scalar placeholder for the probability of - dropout. - """ - # Reshape to use within a convolutional neural net. - # Last dimension is for "features" - there is only one here, since images are - # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. - with tf.name_scope('reshape'): - x_image = tf.reshape(x, [-1, 28, 28, 1]) - - # First convolutional layer - maps one grayscale image to 32 feature maps. - with tf.name_scope('conv1'): - W_conv1 = weight_variable([5, 5, 1, 32]) - b_conv1 = bias_variable([32]) - h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) - - # Pooling layer - downsamples by 2X. - with tf.name_scope('pool1'): - h_pool1 = max_pool_2x2(h_conv1) - - # Second convolutional layer -- maps 32 feature maps to 64. - with tf.name_scope('conv2'): - W_conv2 = weight_variable([5, 5, 32, 64]) - b_conv2 = bias_variable([64]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) - - # Second pooling layer. - with tf.name_scope('pool2'): - h_pool2 = max_pool_2x2(h_conv2) - - # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image - # is down to 7x7x64 feature maps -- maps this to 1024 features. - with tf.name_scope('fc1'): - W_fc1 = weight_variable([3136, 512]) - b_fc1 = bias_variable([512]) - - h_pool2_flat = tf.reshape(h_pool2, [-1, 3136]) - h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) - - # Dropout - controls the complexity of the model, prevents co-adaptation of - # features. - with tf.name_scope('dropout'): - keep_prob = tf.placeholder(tf.float32) - h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) - - # Map the 1024 features to 10 classes, one for each digit - with tf.name_scope('fc2'): - W_fc2 = weight_variable([512, 10]) - b_fc2 = bias_variable([10]) - - y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 - return y_conv, keep_prob, [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2] + """deepnn builds the graph for a deep net for classifying digits. + + Args: + x: an input tensor with the dimensions (N_examples, 784), where 784 is the + number of pixels in a standard MNIST image. + + Returns: + A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values + equal to the logits of classifying the digit into one of 10 classes (the + digits 0-9). keep_prob is a scalar placeholder for the probability of + dropout. + """ + # Reshape to use within a convolutional neural net. + # Last dimension is for "features" - there is only one here, since images are + # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. + with tf.name_scope("reshape"): + x_image = tf.reshape(x, [-1, 28, 28, 1]) + + # First convolutional layer - maps one grayscale image to 32 feature maps. + with tf.name_scope("conv1"): + W_conv1 = weight_variable([5, 5, 1, 32]) + b_conv1 = bias_variable([32]) + h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) + + # Pooling layer - downsamples by 2X. + with tf.name_scope("pool1"): + h_pool1 = max_pool_2x2(h_conv1) + + # Second convolutional layer -- maps 32 feature maps to 64. + with tf.name_scope("conv2"): + W_conv2 = weight_variable([5, 5, 32, 64]) + b_conv2 = bias_variable([64]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) + + # Second pooling layer. + with tf.name_scope("pool2"): + h_pool2 = max_pool_2x2(h_conv2) + + # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image + # is down to 7x7x64 feature maps -- maps this to 1024 features. + with tf.name_scope("fc1"): + W_fc1 = weight_variable([3136, 512]) + b_fc1 = bias_variable([512]) + + h_pool2_flat = tf.reshape(h_pool2, [-1, 3136]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) + + # Dropout - controls the complexity of the model, prevents co-adaptation of + # features. + with tf.name_scope("dropout"): + keep_prob = tf.placeholder(tf.float32) + h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) + + # Map the 1024 features to 10 classes, one for each digit + with tf.name_scope("fc2"): + W_fc2 = weight_variable([512, 10]) + b_fc2 = bias_variable([10]) + + y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 + return ( + y_conv, + keep_prob, + [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2], + ) def conv2d(x, W): - """conv2d returns a 2d convolution layer with full stride.""" - return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') + """conv2d returns a 2d convolution layer with full stride.""" + return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME") def max_pool_2x2(x): - """max_pool_2x2 downsamples a feature map by 2X.""" - return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], - strides=[1, 2, 2, 1], padding='SAME') + """max_pool_2x2 downsamples a feature map by 2X.""" + return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME") def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - initial = tf.truncated_normal(shape, stddev=0.1) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + initial = tf.truncated_normal(shape, stddev=0.1) + return tf.Variable(initial) def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) def main(_): - # Import data - mnist = input_data.read_data_sets(FLAGS.data_dir) - - # Create the model - x = tf.placeholder(tf.float32, [None, 784]) - - # Define loss and optimizer - y_ = tf.placeholder(tf.int64, [None]) - - # Build the graph for the deep net - y_conv, keep_prob, evalTensors = deepnn(x) - - with tf.name_scope('loss'): - cross_entropy = tf.losses.sparse_softmax_cross_entropy( - labels=y_, logits=y_conv) - cross_entropy = tf.reduce_mean(cross_entropy) - - with tf.name_scope('adam_optimizer'): - train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) - - with tf.name_scope('accuracy'): - correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_) - correct_prediction = tf.cast(correct_prediction, tf.float32) - accuracy = tf.reduce_mean(correct_prediction) - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - for i in range(1000): - batch = mnist.train.next_batch(50) - if i % 100 == 0: - train_accuracy = accuracy.eval(feed_dict={ - x: batch[0], y_: batch[1], keep_prob: 1.0}) - print('step %d, training accuracy %g' % (i, train_accuracy)) - train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) - - # compute in batches to avoid OOM on GPUs - accuracy_l = [] - for _ in range(20): - batch = mnist.test.next_batch(500, shuffle=False) - accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], - y_: batch[1], - keep_prob: 1.0})) - print('test accuracy %g' % numpy.mean(accuracy_l)) - - # Dump trained parameters to a checkpoint file - saver = tf.train.Saver(evalTensors) - print(saver.save(sess, './TrainedModel/lenetLargeModel')) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--data_dir', type=str, - default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) \ No newline at end of file + # Import data + mnist = input_data.read_data_sets(FLAGS.data_dir) + + # Create the model + x = tf.placeholder(tf.float32, [None, 784]) + + # Define loss and optimizer + y_ = tf.placeholder(tf.int64, [None]) + + # Build the graph for the deep net + y_conv, keep_prob, evalTensors = deepnn(x) + + with tf.name_scope("loss"): + cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y_conv) + cross_entropy = tf.reduce_mean(cross_entropy) + + with tf.name_scope("adam_optimizer"): + train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) + + with tf.name_scope("accuracy"): + correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_) + correct_prediction = tf.cast(correct_prediction, tf.float32) + accuracy = tf.reduce_mean(correct_prediction) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for i in range(1000): + batch = mnist.train.next_batch(50) + if i % 100 == 0: + train_accuracy = accuracy.eval( + feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0} + ) + print("step %d, training accuracy %g" % (i, train_accuracy)) + train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) + + # compute in batches to avoid OOM on GPUs + accuracy_l = [] + for _ in range(20): + batch = mnist.test.next_batch(500, shuffle=False) + accuracy_l.append( + accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0}) + ) + print("test accuracy %g" % numpy.mean(accuracy_l)) + + # Dump trained parameters to a checkpoint file + saver = tf.train.Saver(evalTensors) + print(saver.save(sess, "./TrainedModel/lenetLargeModel")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/tensorflow/mnist/input_data", + help="Directory for storing input data", + ) + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/Athos/Networks/Lenet/lenetSmall_mnist_inference.py b/Athos/Networks/Lenet/lenetSmall_mnist_inference.py index d9adbaa5..d4cfd68a 100644 --- a/Athos/Networks/Lenet/lenetSmall_mnist_inference.py +++ b/Athos/Networks/Lenet/lenetSmall_mnist_inference.py @@ -36,150 +36,178 @@ import tensorflow as tf import time -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData FLAGS = None + def deepnn(x): - """deepnn builds the graph for a deep net for classifying digits. - - Args: - x: an input tensor with the dimensions (N_examples, 784), where 784 is the - number of pixels in a standard MNIST image. - - Returns: - A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values - equal to the logits of classifying the digit into one of 10 classes (the - digits 0-9). keep_prob is a scalar placeholder for the probability of - dropout. - """ - # Reshape to use within a convolutional neural net. - # Last dimension is for "features" - there is only one here, since images are - # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. - with tf.name_scope('reshape'): - x_image = tf.reshape(x, [-1, 28, 28, 1]) - - # First convolutional layer - maps one grayscale image to 32 feature maps. - with tf.name_scope('conv1'): - W_conv1 = weight_variable([5, 5, 1, 16]) - b_conv1 = bias_variable([16]) - h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) - - # Pooling layer - downsamples by 2X. - with tf.name_scope('pool1'): - h_pool1 = max_pool_2x2(h_conv1) - - # Second convolutional layer -- maps 32 feature maps to 64. - with tf.name_scope('conv2'): - W_conv2 = weight_variable([5, 5, 16, 16]) - b_conv2 = bias_variable([16]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) - - # Second pooling layer. - with tf.name_scope('pool2'): - h_pool2 = max_pool_2x2(h_conv2) - - # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image - # is down to 7x7x64 feature maps -- maps this to 1024 features. - with tf.name_scope('fc1'): - W_fc1 = weight_variable([256, 100]) - b_fc1 = bias_variable([100]) - - h_pool2_flat = tf.reshape(h_pool2, [-1, 256]) - h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) - - # Dropout - controls the complexity of the model, prevents co-adaptation of - # features. - # with tf.name_scope('dropout'): - keep_prob = tf.placeholder(tf.float32) - # h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) - - # Map the 1024 features to 10 classes, one for each digit - with tf.name_scope('fc2'): - W_fc2 = weight_variable([100, 10]) - b_fc2 = bias_variable([10]) - - y_conv = tf.matmul(h_fc1, W_fc2) + b_fc2 - return y_conv, keep_prob, [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2] + """deepnn builds the graph for a deep net for classifying digits. + + Args: + x: an input tensor with the dimensions (N_examples, 784), where 784 is the + number of pixels in a standard MNIST image. + + Returns: + A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values + equal to the logits of classifying the digit into one of 10 classes (the + digits 0-9). keep_prob is a scalar placeholder for the probability of + dropout. + """ + # Reshape to use within a convolutional neural net. + # Last dimension is for "features" - there is only one here, since images are + # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. + with tf.name_scope("reshape"): + x_image = tf.reshape(x, [-1, 28, 28, 1]) + + # First convolutional layer - maps one grayscale image to 32 feature maps. + with tf.name_scope("conv1"): + W_conv1 = weight_variable([5, 5, 1, 16]) + b_conv1 = bias_variable([16]) + h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) + + # Pooling layer - downsamples by 2X. + with tf.name_scope("pool1"): + h_pool1 = max_pool_2x2(h_conv1) + + # Second convolutional layer -- maps 32 feature maps to 64. + with tf.name_scope("conv2"): + W_conv2 = weight_variable([5, 5, 16, 16]) + b_conv2 = bias_variable([16]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) + + # Second pooling layer. + with tf.name_scope("pool2"): + h_pool2 = max_pool_2x2(h_conv2) + + # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image + # is down to 7x7x64 feature maps -- maps this to 1024 features. + with tf.name_scope("fc1"): + W_fc1 = weight_variable([256, 100]) + b_fc1 = bias_variable([100]) + + h_pool2_flat = tf.reshape(h_pool2, [-1, 256]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) + + # Dropout - controls the complexity of the model, prevents co-adaptation of + # features. + # with tf.name_scope('dropout'): + keep_prob = tf.placeholder(tf.float32) + # h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) + + # Map the 1024 features to 10 classes, one for each digit + with tf.name_scope("fc2"): + W_fc2 = weight_variable([100, 10]) + b_fc2 = bias_variable([10]) + + y_conv = tf.matmul(h_fc1, W_fc2) + b_fc2 + return ( + y_conv, + keep_prob, + [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2], + ) + def conv2d(x, W): - """conv2d returns a 2d convolution layer with full stride.""" - return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID') + """conv2d returns a 2d convolution layer with full stride.""" + return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="VALID") + def max_pool_2x2(x): - """max_pool_2x2 downsamples a feature map by 2X.""" - return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], - strides=[1, 2, 2, 1], padding='VALID') + """max_pool_2x2 downsamples a feature map by 2X.""" + return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID") + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - initial = tf.constant(0.25, shape=shape) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + initial = tf.constant(0.25, shape=shape) + return tf.Variable(initial) + def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.25, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.25, shape=shape) + return tf.Variable(initial) + def findLabel(oneHotAns): - for i in range(10): - if oneHotAns[i] == 1.0: - return i - return -1 + for i in range(10): + if oneHotAns[i] == 1.0: + return i + return -1 + def main(_): - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,seed=1) - x = tf.placeholder(tf.float32, [None, 784]) - y_ = tf.placeholder(tf.float32, [None, 10]) - y_conv, keep_prob, modelWeights = deepnn(x) - pred = tf.argmax(y_conv, 1) - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - - if len(sys.argv)!=2: - print("Need the mnist image number to run inference on.") - exit(-1) - - curImageNum = int(sys.argv[1]) - imagex = mnist.test.images[curImageNum:curImageNum+1,:] - imagey = mnist.test.labels[curImageNum:curImageNum+1,:] - keep_prob_value = 1.0 - feed_dict = {x:imagex, y_:imagey, keep_prob:keep_prob_value} - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - # if node.name == 'fc2/add': - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - saver = tf.train.Saver(modelWeights) - saver.restore(sess, './TrainedModel/lenetSmallModel') - - start_time = time.time() - prediction = sess.run([y_conv, keep_prob],feed_dict=feed_dict) - duration = time.time() - start_time - - print("Duration of execution : ", duration) - print('Result ::::::::: \n', prediction[0]) - - print("Prediction: ", np.argmax(prediction[0])) - print("Actual label: ", findLabel(imagey[0])) - - trainVarsName = [] - for node in optimized_graph_def.node: - if node.op=="VariableV2": - trainVarsName.append(node.name) - trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, imagex[0], trainVars, 'LenetSmall_img_{0}.inp'.format(curImageNum), 'LenetSmall_weights_{0}.inp'.format(curImageNum), 15) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--data_dir', type=str, - default='/tmp/data', - help='Directory for storing input data') - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) \ No newline at end of file + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True, seed=1) + x = tf.placeholder(tf.float32, [None, 784]) + y_ = tf.placeholder(tf.float32, [None, 10]) + y_conv, keep_prob, modelWeights = deepnn(x) + pred = tf.argmax(y_conv, 1) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + if len(sys.argv) != 2: + print("Need the mnist image number to run inference on.") + exit(-1) + + curImageNum = int(sys.argv[1]) + imagex = mnist.test.images[curImageNum : curImageNum + 1, :] + imagey = mnist.test.labels[curImageNum : curImageNum + 1, :] + keep_prob_value = 1.0 + feed_dict = {x: imagex, y_: imagey, keep_prob: keep_prob_value} + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + # if node.name == 'fc2/add': + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) + + saver = tf.train.Saver(modelWeights) + saver.restore(sess, "./TrainedModel/lenetSmallModel") + + start_time = time.time() + prediction = sess.run([y_conv, keep_prob], feed_dict=feed_dict) + duration = time.time() - start_time + + print("Duration of execution : ", duration) + print("Result ::::::::: \n", prediction[0]) + + print("Prediction: ", np.argmax(prediction[0])) + print("Actual label: ", findLabel(imagey[0])) + + trainVarsName = [] + for node in optimized_graph_def.node: + if node.op == "VariableV2": + trainVarsName.append(node.name) + trainVars = list( + map( + lambda x: tf.get_default_graph().get_operation_by_name(x).outputs[0], + trainVarsName, + ) + ) + DumpTFMtData.dumpImgAndWeightsDataSeparate( + sess, + imagex[0], + trainVars, + "LenetSmall_img_{0}.inp".format(curImageNum), + "LenetSmall_weights_{0}.inp".format(curImageNum), + 15, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/data", + help="Directory for storing input data", + ) + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/Athos/Networks/Lenet/lenetSmall_mnist_train.py b/Athos/Networks/Lenet/lenetSmall_mnist_train.py index 2229fcd2..7a732e11 100644 --- a/Athos/Networks/Lenet/lenetSmall_mnist_train.py +++ b/Athos/Networks/Lenet/lenetSmall_mnist_train.py @@ -40,139 +40,151 @@ def deepnn(x): - """deepnn builds the graph for a deep net for classifying digits. - - Args: - x: an input tensor with the dimensions (N_examples, 784), where 784 is the - number of pixels in a standard MNIST image. - - Returns: - A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values - equal to the logits of classifying the digit into one of 10 classes (the - digits 0-9). keep_prob is a scalar placeholder for the probability of - dropout. - """ - # Reshape to use within a convolutional neural net. - # Last dimension is for "features" - there is only one here, since images are - # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. - with tf.name_scope('reshape'): - x_image = tf.reshape(x, [-1, 28, 28, 1]) - - # First convolutional layer - maps one grayscale image to 32 feature maps. - with tf.name_scope('conv1'): - W_conv1 = weight_variable([5, 5, 1, 16]) - b_conv1 = bias_variable([16]) - h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) - - # Pooling layer - downsamples by 2X. - with tf.name_scope('pool1'): - h_pool1 = max_pool_2x2(h_conv1) - - # Second convolutional layer -- maps 32 feature maps to 64. - with tf.name_scope('conv2'): - W_conv2 = weight_variable([5, 5, 16, 16]) - b_conv2 = bias_variable([16]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) - - # Second pooling layer. - with tf.name_scope('pool2'): - h_pool2 = max_pool_2x2(h_conv2) - - # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image - # is down to 7x7x64 feature maps -- maps this to 1024 features. - with tf.name_scope('fc1'): - W_fc1 = weight_variable([256, 100]) - b_fc1 = bias_variable([100]) - - h_pool2_flat = tf.reshape(h_pool2, [-1, 256]) - h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) - - # Dropout - controls the complexity of the model, prevents co-adaptation of - # features. - with tf.name_scope('dropout'): - keep_prob = tf.placeholder(tf.float32) - h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) - - # Map the 1024 features to 10 classes, one for each digit - with tf.name_scope('fc2'): - W_fc2 = weight_variable([100, 10]) - b_fc2 = bias_variable([10]) - - y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 - return y_conv, keep_prob, [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2] + """deepnn builds the graph for a deep net for classifying digits. + + Args: + x: an input tensor with the dimensions (N_examples, 784), where 784 is the + number of pixels in a standard MNIST image. + + Returns: + A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values + equal to the logits of classifying the digit into one of 10 classes (the + digits 0-9). keep_prob is a scalar placeholder for the probability of + dropout. + """ + # Reshape to use within a convolutional neural net. + # Last dimension is for "features" - there is only one here, since images are + # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. + with tf.name_scope("reshape"): + x_image = tf.reshape(x, [-1, 28, 28, 1]) + + # First convolutional layer - maps one grayscale image to 32 feature maps. + with tf.name_scope("conv1"): + W_conv1 = weight_variable([5, 5, 1, 16]) + b_conv1 = bias_variable([16]) + h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) + + # Pooling layer - downsamples by 2X. + with tf.name_scope("pool1"): + h_pool1 = max_pool_2x2(h_conv1) + + # Second convolutional layer -- maps 32 feature maps to 64. + with tf.name_scope("conv2"): + W_conv2 = weight_variable([5, 5, 16, 16]) + b_conv2 = bias_variable([16]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) + + # Second pooling layer. + with tf.name_scope("pool2"): + h_pool2 = max_pool_2x2(h_conv2) + + # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image + # is down to 7x7x64 feature maps -- maps this to 1024 features. + with tf.name_scope("fc1"): + W_fc1 = weight_variable([256, 100]) + b_fc1 = bias_variable([100]) + + h_pool2_flat = tf.reshape(h_pool2, [-1, 256]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) + + # Dropout - controls the complexity of the model, prevents co-adaptation of + # features. + with tf.name_scope("dropout"): + keep_prob = tf.placeholder(tf.float32) + h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) + + # Map the 1024 features to 10 classes, one for each digit + with tf.name_scope("fc2"): + W_fc2 = weight_variable([100, 10]) + b_fc2 = bias_variable([10]) + + y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 + return ( + y_conv, + keep_prob, + [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2], + ) + def conv2d(x, W): - """conv2d returns a 2d convolution layer with full stride.""" - return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID') + """conv2d returns a 2d convolution layer with full stride.""" + return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="VALID") + def max_pool_2x2(x): - """max_pool_2x2 downsamples a feature map by 2X.""" - return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], - strides=[1, 2, 2, 1], padding='VALID') + """max_pool_2x2 downsamples a feature map by 2X.""" + return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID") + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - initial = tf.truncated_normal(shape, stddev=0.1) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + initial = tf.truncated_normal(shape, stddev=0.1) + return tf.Variable(initial) + def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) + def main(_): - # Import data - mnist = input_data.read_data_sets(FLAGS.data_dir) - - # Create the model - x = tf.placeholder(tf.float32, [None, 784]) - - # Define loss and optimizer - y_ = tf.placeholder(tf.int64, [None]) - - # Build the graph for the deep net - y_conv, keep_prob, evalTensors = deepnn(x) - - with tf.name_scope('loss'): - cross_entropy = tf.losses.sparse_softmax_cross_entropy( - labels=y_, logits=y_conv) - cross_entropy = tf.reduce_mean(cross_entropy) - - with tf.name_scope('adam_optimizer'): - train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) - - with tf.name_scope('accuracy'): - correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_) - correct_prediction = tf.cast(correct_prediction, tf.float32) - accuracy = tf.reduce_mean(correct_prediction) - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - for i in range(1000): - batch = mnist.train.next_batch(50) - if i % 100 == 0: - train_accuracy = accuracy.eval(feed_dict={ - x: batch[0], y_: batch[1], keep_prob: 1.0}) - print('step %d, training accuracy %g' % (i, train_accuracy)) - train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) - - # compute in batches to avoid OOM on GPUs - accuracy_l = [] - for _ in range(20): - batch = mnist.test.next_batch(500, shuffle=False) - accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], - y_: batch[1], - keep_prob: 1.0})) - print('test accuracy %g' % numpy.mean(accuracy_l)) - - # Dump trained parameters to a checkpoint file - saver = tf.train.Saver(evalTensors) - print(saver.save(sess, './TrainedModel/lenetSmallModel')) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--data_dir', type=str, - default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) \ No newline at end of file + # Import data + mnist = input_data.read_data_sets(FLAGS.data_dir) + + # Create the model + x = tf.placeholder(tf.float32, [None, 784]) + + # Define loss and optimizer + y_ = tf.placeholder(tf.int64, [None]) + + # Build the graph for the deep net + y_conv, keep_prob, evalTensors = deepnn(x) + + with tf.name_scope("loss"): + cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y_conv) + cross_entropy = tf.reduce_mean(cross_entropy) + + with tf.name_scope("adam_optimizer"): + train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) + + with tf.name_scope("accuracy"): + correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_) + correct_prediction = tf.cast(correct_prediction, tf.float32) + accuracy = tf.reduce_mean(correct_prediction) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for i in range(1000): + batch = mnist.train.next_batch(50) + if i % 100 == 0: + train_accuracy = accuracy.eval( + feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0} + ) + print("step %d, training accuracy %g" % (i, train_accuracy)) + train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) + + # compute in batches to avoid OOM on GPUs + accuracy_l = [] + for _ in range(20): + batch = mnist.test.next_batch(500, shuffle=False) + accuracy_l.append( + accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0}) + ) + print("test accuracy %g" % numpy.mean(accuracy_l)) + + # Dump trained parameters to a checkpoint file + saver = tf.train.Saver(evalTensors) + print(saver.save(sess, "./TrainedModel/lenetSmallModel")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/tensorflow/mnist/input_data", + help="Directory for storing input data", + ) + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/Athos/Networks/LogisticRegression/LogisticRegressionInfer.py b/Athos/Networks/LogisticRegression/LogisticRegressionInfer.py index 62178b50..b3772de8 100644 --- a/Athos/Networks/LogisticRegression/LogisticRegressionInfer.py +++ b/Athos/Networks/LogisticRegression/LogisticRegressionInfer.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -22,7 +22,7 @@ Parts of the code in this file is taken from https://github.com/aymericdamien/TensorFlow-Examples/. -''' +""" from __future__ import print_function import os, sys @@ -31,74 +31,91 @@ import matplotlib import tensorflow as tf import matplotlib.pyplot as plt -matplotlib.use('Agg') + +matplotlib.use("Agg") from tensorflow.examples.tutorials.mnist import input_data -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) -x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784 +x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784 # Model weights -W = tf.Variable(tf.constant(0.1, shape=[784,10])) +W = tf.Variable(tf.constant(0.1, shape=[784, 10])) b = tf.Variable(tf.constant(0.2, shape=[10])) # Construct model pred = tf.argmax(tf.matmul(x, W) + b, 1) init = tf.global_variables_initializer() + def findLabel(oneHotAns): - for i in range(10): - if oneHotAns[i] == 1.0: - return i - return -1 + for i in range(10): + if oneHotAns[i] == 1.0: + return i + return -1 + def saveImg(seqNum, imaged): - temp = imaged.reshape([28,28]) - # plt.gray() - plt.imshow(temp) - plt.savefig("MNIST_test_image" + str(seqNum) + ".png") + temp = imaged.reshape([28, 28]) + # plt.gray() + plt.imshow(temp) + plt.savefig("MNIST_test_image" + str(seqNum) + ".png") -with tf.Session() as sess: - sess.run(init) - - if len(sys.argv)!=2: - print("Need the mnist image number to run inference on.") - exit(-1) - - curImageNum = int(sys.argv[1]) - - imagex = mnist.test.images[curImageNum:curImageNum+1,:] - imagey = mnist.test.labels[curImageNum:curImageNum+1,:] - # saveImg(curImageNum, imagex[0]) # save the image so that so pic viewer can render it - feed_dict = {x:imagex} - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - evalTensors = [W,b] - saver = tf.train.Saver(evalTensors) - saver.restore(sess, './TrainedModel/model') - - start_time = time.time() - outp = sess.run(pred, feed_dict) - end_time = time.time() - - print("Duration of execution = ", (end_time-start_time)) - print(outp) - correctLabel = findLabel(imagey[0]) - print("Correct label = " + str(correctLabel)) - print("Dumping of values into .inp file...") - trainVarsName = [] - for node in optimized_graph_def.node: - if node.op=="VariableV2": - trainVarsName.append(node.name) - trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, imagex[0], trainVars, 'LR_img_{0}.inp'.format(curImageNum), 'LR_weights_{0}.inp'.format(curImageNum), 15) +with tf.Session() as sess: + sess.run(init) + + if len(sys.argv) != 2: + print("Need the mnist image number to run inference on.") + exit(-1) + + curImageNum = int(sys.argv[1]) + + imagex = mnist.test.images[curImageNum : curImageNum + 1, :] + imagey = mnist.test.labels[curImageNum : curImageNum + 1, :] + # saveImg(curImageNum, imagex[0]) # save the image so that so pic viewer can render it + feed_dict = {x: imagex} + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) + + evalTensors = [W, b] + saver = tf.train.Saver(evalTensors) + saver.restore(sess, "./TrainedModel/model") + + start_time = time.time() + outp = sess.run(pred, feed_dict) + end_time = time.time() + + print("Duration of execution = ", (end_time - start_time)) + print(outp) + correctLabel = findLabel(imagey[0]) + print("Correct label = " + str(correctLabel)) + print("Dumping of values into .inp file...") + trainVarsName = [] + for node in optimized_graph_def.node: + if node.op == "VariableV2": + trainVarsName.append(node.name) + trainVars = list( + map( + lambda x: tf.get_default_graph().get_operation_by_name(x).outputs[0], + trainVarsName, + ) + ) + DumpTFMtData.dumpImgAndWeightsDataSeparate( + sess, + imagex[0], + trainVars, + "LR_img_{0}.inp".format(curImageNum), + "LR_weights_{0}.inp".format(curImageNum), + 15, + ) diff --git a/Athos/Networks/LogisticRegression/LogisticRegressionTrain.py b/Athos/Networks/LogisticRegression/LogisticRegressionTrain.py index f2c4c0e7..4635348b 100644 --- a/Athos/Networks/LogisticRegression/LogisticRegressionTrain.py +++ b/Athos/Networks/LogisticRegression/LogisticRegressionTrain.py @@ -1,4 +1,4 @@ -''' +""" A logistic regression learning algorithm example using TensorFlow library. This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/) @@ -6,7 +6,7 @@ Project: https://github.com/aymericdamien/TensorFlow-Examples/ ** Original source code from above modified for CrypTFlow. ** -''' +""" from __future__ import print_function @@ -15,6 +15,7 @@ # Import MNIST data from tensorflow.examples.tutorials.mnist import input_data + mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) # Parameters @@ -24,18 +25,18 @@ display_step = 1 # tf Graph Input -x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784 -y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes +x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784 +y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes # Set model weights W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) # Construct model -pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax +pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax # Minimize error using cross entropy -cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1)) +cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) # Gradient Descent optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) @@ -50,28 +51,27 @@ # Training cycle for epoch in range(training_epochs): - avg_cost = 0. - total_batch = int(mnist.train.num_examples/batch_size) + avg_cost = 0.0 + total_batch = int(mnist.train.num_examples / batch_size) # Loop over all batches for i in range(total_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) # Run optimization op (backprop) and cost op (to get loss value) - _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, - y: batch_ys}) + _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys}) # Compute average loss avg_cost += c / total_batch # Display logs per epoch step - if (epoch+1) % display_step == 0: - print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)) + if (epoch + 1) % display_step == 0: + print("Epoch:", "%04d" % (epoch + 1), "cost=", "{:.9f}".format(avg_cost)) print("Optimization Finished!") ###################################### - evalTensors = [W,b] + evalTensors = [W, b] # First dump the tf model saver = tf.train.Saver(evalTensors) - print(saver.save(sess, './TrainedModel/model')) - + print(saver.save(sess, "./TrainedModel/model")) + ##################################### # Test model diff --git a/Athos/Networks/OtherBenchmarks/MiniONN_CIFAR.py b/Athos/Networks/OtherBenchmarks/MiniONN_CIFAR.py index b78de44c..e3c09920 100644 --- a/Athos/Networks/OtherBenchmarks/MiniONN_CIFAR.py +++ b/Athos/Networks/OtherBenchmarks/MiniONN_CIFAR.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # # This is a NN over the CIFAR-10 dataset, used in the MiniONN paper. @@ -30,84 +30,92 @@ import os, sys import numpy as np import tensorflow as tf -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - # initial = tf.truncated_normal(shape, stddev=0.1) - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + # initial = tf.truncated_normal(shape, stddev=0.1) + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) + def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) + x = tf.placeholder(tf.float32, [None, 32, 32, 3]) -# 1 -w_conv1 = weight_variable([3,3,3,64]) -conv1 = tf.nn.conv2d(x, w_conv1, strides=[1,1,1,1],padding='SAME') +# 1 +w_conv1 = weight_variable([3, 3, 3, 64]) +conv1 = tf.nn.conv2d(x, w_conv1, strides=[1, 1, 1, 1], padding="SAME") relu2 = tf.nn.relu(conv1) # 3 -w_conv3 = weight_variable([3,3,64,64]) -conv3 = tf.nn.conv2d(relu2, w_conv3, strides=[1,1,1,1],padding='SAME') +w_conv3 = weight_variable([3, 3, 64, 64]) +conv3 = tf.nn.conv2d(relu2, w_conv3, strides=[1, 1, 1, 1], padding="SAME") relu4 = tf.nn.relu(conv3) # 5 -avgpool5 = tf.nn.avg_pool(relu4, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') +avgpool5 = tf.nn.avg_pool( + relu4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME" +) # 6 -w_conv6 = weight_variable([3,3,64,64]) -conv6 = tf.nn.conv2d(avgpool5, w_conv6, strides=[1,1,1,1], padding='SAME') +w_conv6 = weight_variable([3, 3, 64, 64]) +conv6 = tf.nn.conv2d(avgpool5, w_conv6, strides=[1, 1, 1, 1], padding="SAME") relu7 = tf.nn.relu(conv6) # 8 -w_conv8 = weight_variable([3,3,64,64]) -conv8 = tf.nn.conv2d(relu7, w_conv8, strides=[1,1,1,1], padding='SAME') +w_conv8 = weight_variable([3, 3, 64, 64]) +conv8 = tf.nn.conv2d(relu7, w_conv8, strides=[1, 1, 1, 1], padding="SAME") relu9 = tf.nn.relu(conv8) # 10 -avgpool10 = tf.nn.avg_pool(relu9, ksize=[1,2,2,1], strides=[1,1,1,1], padding='SAME') +avgpool10 = tf.nn.avg_pool( + relu9, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding="SAME" +) # 11 -w_conv11 = weight_variable([3,3,64,64]) -conv11 = tf.nn.conv2d(avgpool10, w_conv11, strides=[1,2,2,1], padding='SAME') +w_conv11 = weight_variable([3, 3, 64, 64]) +conv11 = tf.nn.conv2d(avgpool10, w_conv11, strides=[1, 2, 2, 1], padding="SAME") relu12 = tf.nn.relu(conv11) # 13 -w_conv13 = weight_variable([1,1,64,64]) -conv13 = tf.nn.conv2d(relu12, w_conv13, strides=[1,1,1,1], padding='SAME') +w_conv13 = weight_variable([1, 1, 64, 64]) +conv13 = tf.nn.conv2d(relu12, w_conv13, strides=[1, 1, 1, 1], padding="SAME") relu14 = tf.nn.relu(conv13) # 15 -w_conv15 = weight_variable([1,1,64,16]) -conv15 = tf.nn.conv2d(relu14, w_conv15, strides=[1,1,1,1], padding='SAME') +w_conv15 = weight_variable([1, 1, 64, 16]) +conv15 = tf.nn.conv2d(relu14, w_conv15, strides=[1, 1, 1, 1], padding="SAME") relu16 = tf.nn.relu(conv15) # # 17 -w_fc17 = weight_variable([1024,10]) +w_fc17 = weight_variable([1024, 10]) b_fc17 = bias_variable([10]) -inpfc17 = tf.reshape(relu16, [-1,1024]) +inpfc17 = tf.reshape(relu16, [-1, 1024]) fc17 = tf.matmul(inpfc17, w_fc17) + b_fc17 -finalOut = tf.argmax(fc17,1) +finalOut = tf.argmax(fc17, 1) with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - imgData = np.full((1,32,32,3), 0.1) - feed_dict = {x: imgData} - - pred = sess.run(finalOut, feed_dict=feed_dict) - print(pred) - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - + sess.run(tf.global_variables_initializer()) + imgData = np.full((1, 32, 32, 3), 0.1) + feed_dict = {x: imgData} + + pred = sess.run(finalOut, feed_dict=feed_dict) + print(pred) + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) diff --git a/Athos/Networks/OtherBenchmarks/resnet32_cifar100.py b/Athos/Networks/OtherBenchmarks/resnet32_cifar100.py index eb439ef9..a67b807e 100644 --- a/Athos/Networks/OtherBenchmarks/resnet32_cifar100.py +++ b/Athos/Networks/OtherBenchmarks/resnet32_cifar100.py @@ -1,5 +1,3 @@ - - # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,11 +30,22 @@ from tensorflow.keras import backend from tensorflow.keras import layers -from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, add, ZeroPadding2D, GlobalAveragePooling2D, AveragePooling2D, Dense, Reshape, Lambda +from tensorflow.keras.layers import ( + Conv2D, + BatchNormalization, + Activation, + add, + ZeroPadding2D, + GlobalAveragePooling2D, + AveragePooling2D, + Dense, + Reshape, + Lambda, +) from tensorflow.keras.regularizers import l2 from tensorflow.keras.utils import get_custom_objects -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData BATCH_NORM_DECAY = 0.997 @@ -44,7 +53,9 @@ L2_WEIGHT_DECAY = 2e-4 -def identity_building_block(input_tensor, kernel_size, filters, stage, block, training=None): +def identity_building_block( + input_tensor, kernel_size, filters, stage, block, training=None +): """The identity block is the block that has no conv layer at shortcut. Arguments: @@ -61,44 +72,45 @@ def identity_building_block(input_tensor, kernel_size, filters, stage, block, tr Output tensor for the block. """ filters1, filters2 = filters - bn_axis=1 - if tf.keras.backend.image_data_format() == 'channels_last': - bn_axis = 3 + bn_axis = 1 + if tf.keras.backend.image_data_format() == "channels_last": + bn_axis = 3 else: - bn_axis = 1 - - - x = Conv2D(filters1, kernel_size, - padding='same', - kernel_initializer='he_normal', - kernel_regularizer= - l2(L2_WEIGHT_DECAY), - bias_regularizer= - l2(L2_WEIGHT_DECAY))(input_tensor) - x = BatchNormalization(axis=bn_axis, - momentum=BATCH_NORM_DECAY, - epsilon=BATCH_NORM_EPSILON,fused=True)( - x, training=training) - - x = Activation('approx_activation')(x) - - x = Conv2D(filters2, kernel_size, - padding='same', - kernel_initializer='he_normal', - kernel_regularizer= - l2(L2_WEIGHT_DECAY), - bias_regularizer= - l2(L2_WEIGHT_DECAY))(x) - x = BatchNormalization(axis=bn_axis, - momentum=BATCH_NORM_DECAY, - epsilon=BATCH_NORM_EPSILON,fused=True)( - x, training=training) + bn_axis = 1 + + x = Conv2D( + filters1, + kernel_size, + padding="same", + kernel_initializer="he_normal", + kernel_regularizer=l2(L2_WEIGHT_DECAY), + bias_regularizer=l2(L2_WEIGHT_DECAY), + )(input_tensor) + x = BatchNormalization( + axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON, fused=True + )(x, training=training) + + x = Activation("approx_activation")(x) + + x = Conv2D( + filters2, + kernel_size, + padding="same", + kernel_initializer="he_normal", + kernel_regularizer=l2(L2_WEIGHT_DECAY), + bias_regularizer=l2(L2_WEIGHT_DECAY), + )(x) + x = BatchNormalization( + axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON, fused=True + )(x, training=training) x = add([x, input_tensor]) - x = Activation('approx_activation')(x) + x = Activation("approx_activation")(x) return x -def conv_building_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), training=None): +def conv_building_block( + input_tensor, kernel_size, filters, stage, block, strides=(2, 2), training=None +): """A block that has a conv layer at shortcut. Arguments: @@ -120,53 +132,58 @@ def conv_building_block(input_tensor, kernel_size, filters, stage, block, stride And the shortcut should have strides=(2, 2) as well """ filters1, filters2 = filters - bn_axis=1 - if tf.keras.backend.image_data_format() == 'channels_last': - bn_axis = 3 + bn_axis = 1 + if tf.keras.backend.image_data_format() == "channels_last": + bn_axis = 3 else: - bn_axis = 1 - - x = Conv2D(filters1, kernel_size, strides=strides, - padding='same', - kernel_initializer='he_normal', - kernel_regularizer= - l2(L2_WEIGHT_DECAY), - bias_regularizer= - l2(L2_WEIGHT_DECAY))(input_tensor) - x = BatchNormalization(axis=bn_axis, - momentum=BATCH_NORM_DECAY, - epsilon=BATCH_NORM_EPSILON,fused=True)( - x, training=training) - x = Activation('approx_activation')(x) - - x = Conv2D(filters2, kernel_size, padding='same', - kernel_initializer='he_normal', - kernel_regularizer= - l2(L2_WEIGHT_DECAY), - bias_regularizer= - l2(L2_WEIGHT_DECAY))(x) - x = BatchNormalization(axis=bn_axis, - momentum=BATCH_NORM_DECAY, - epsilon=BATCH_NORM_EPSILON,fused=True)( - x, training=training) - - shortcut = Conv2D(filters2, (1, 1), strides=strides, - kernel_initializer='he_normal', - kernel_regularizer= - l2(L2_WEIGHT_DECAY), - bias_regularizer= - l2(L2_WEIGHT_DECAY))(input_tensor) + bn_axis = 1 + + x = Conv2D( + filters1, + kernel_size, + strides=strides, + padding="same", + kernel_initializer="he_normal", + kernel_regularizer=l2(L2_WEIGHT_DECAY), + bias_regularizer=l2(L2_WEIGHT_DECAY), + )(input_tensor) + x = BatchNormalization( + axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON, fused=True + )(x, training=training) + x = Activation("approx_activation")(x) + + x = Conv2D( + filters2, + kernel_size, + padding="same", + kernel_initializer="he_normal", + kernel_regularizer=l2(L2_WEIGHT_DECAY), + bias_regularizer=l2(L2_WEIGHT_DECAY), + )(x) + x = BatchNormalization( + axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON, fused=True + )(x, training=training) + + shortcut = Conv2D( + filters2, + (1, 1), + strides=strides, + kernel_initializer="he_normal", + kernel_regularizer=l2(L2_WEIGHT_DECAY), + bias_regularizer=l2(L2_WEIGHT_DECAY), + )(input_tensor) shortcut = BatchNormalization( - axis=bn_axis, - momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,fused=True)( - shortcut, training=training) + axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON, fused=True + )(shortcut, training=training) x = add([x, shortcut]) - x = Activation('approx_activation')(x) + x = Activation("approx_activation")(x) return x -def resnet_block(input_tensor, size, kernel_size, filters, stage, conv_strides=(2, 2), training=None): +def resnet_block( + input_tensor, size, kernel_size, filters, stage, conv_strides=(2, 2), training=None +): """A block which applies conv followed by multiple identity blocks. Arguments: @@ -185,14 +202,27 @@ def resnet_block(input_tensor, size, kernel_size, filters, stage, conv_strides=( Output tensor after applying conv and identity blocks. """ - x = conv_building_block(input_tensor, kernel_size, filters, stage=stage, - strides=conv_strides, block='block_0', - training=training) + x = conv_building_block( + input_tensor, + kernel_size, + filters, + stage=stage, + strides=conv_strides, + block="block_0", + training=training, + ) for i in range(size - 1): - x = identity_building_block(x, kernel_size, filters, stage=stage, - block='block_%d' % (i + 1), training=training) + x = identity_building_block( + x, + kernel_size, + filters, + stage=stage, + block="block_%d" % (i + 1), + training=training, + ) return x + def build(): """Instantiates ResNet32 model """ # Parameters for Resnet32 on Cifar-100 @@ -204,44 +234,66 @@ def build(): img_input = layers.Input(shape=input_shape) x = img_input bn_axis = 1 - if tf.keras.backend.image_data_format() == 'channels_last': - bn_axis = 3 + if tf.keras.backend.image_data_format() == "channels_last": + bn_axis = 3 else: - bn_axis = 1 + bn_axis = 1 x = ZeroPadding2D(padding=(1, 1))(x) - x = Conv2D(16, (3, 3), - strides=(1, 1), - padding='valid', - kernel_initializer='he_normal', - kernel_regularizer= - l2(L2_WEIGHT_DECAY), - bias_regularizer= - l2(L2_WEIGHT_DECAY))(x) - x = BatchNormalization(axis=bn_axis, - momentum=BATCH_NORM_DECAY, - epsilon=BATCH_NORM_EPSILON,fused=True)( - x, training=training) - x = Activation('approx_activation')(x) - - x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16], - stage=2, conv_strides=(1, 1), training=training) - - x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[32, 32], - stage=3, conv_strides=(2, 2), training=training) - - x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64], - stage=4, conv_strides=(2, 2), training=training) - - x = AveragePooling2D(pool_size=(8,8),strides=(1,1),padding="VALID")(x) - x = Lambda(lambda w: tf.keras.backend.squeeze(w,1))(x) - x = Lambda(lambda w: tf.keras.backend.squeeze(w,1))(x) - x = Dense(classes,activation='softmax', - kernel_initializer='he_normal', - kernel_regularizer= - l2(L2_WEIGHT_DECAY), - bias_regularizer= - l2(L2_WEIGHT_DECAY))(x) + x = Conv2D( + 16, + (3, 3), + strides=(1, 1), + padding="valid", + kernel_initializer="he_normal", + kernel_regularizer=l2(L2_WEIGHT_DECAY), + bias_regularizer=l2(L2_WEIGHT_DECAY), + )(x) + x = BatchNormalization( + axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON, fused=True + )(x, training=training) + x = Activation("approx_activation")(x) + + x = resnet_block( + x, + size=num_blocks, + kernel_size=3, + filters=[16, 16], + stage=2, + conv_strides=(1, 1), + training=training, + ) + + x = resnet_block( + x, + size=num_blocks, + kernel_size=3, + filters=[32, 32], + stage=3, + conv_strides=(2, 2), + training=training, + ) + + x = resnet_block( + x, + size=num_blocks, + kernel_size=3, + filters=[64, 64], + stage=4, + conv_strides=(2, 2), + training=training, + ) + + x = AveragePooling2D(pool_size=(8, 8), strides=(1, 1), padding="VALID")(x) + x = Lambda(lambda w: tf.keras.backend.squeeze(w, 1))(x) + x = Lambda(lambda w: tf.keras.backend.squeeze(w, 1))(x) + x = Dense( + classes, + activation="softmax", + kernel_initializer="he_normal", + kernel_regularizer=l2(L2_WEIGHT_DECAY), + bias_regularizer=l2(L2_WEIGHT_DECAY), + )(x) inputs = img_input # Create model. @@ -249,25 +301,31 @@ def build(): return model + def main(): - get_custom_objects().update({'approx_activation': Activation(tf.keras.activations.relu)}) + get_custom_objects().update( + {"approx_activation": Activation(tf.keras.activations.relu)} + ) model = build() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) - imgData = np.full((1,32,32,3), 2.3) - feed_dict={'input_1:0':imgData} - + imgData = np.full((1, 32, 32, 3), 2.3) + feed_dict = {"input_1:0": imgData} + ans = model.predict(imgData) print(ans) - + output_tensor = None gg = tf.get_default_graph() for node in gg.as_graph_def().node: - if node.name == 'dense/Softmax': + if node.name == "dense/Softmax": output_tensor = gg.get_operation_by_name(node.name).outputs[0] - assert(output_tensor is not None) - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict=feed_dict) + assert output_tensor is not None + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict=feed_dict + ) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/Athos/Networks/ResNet/AccuracyAnalysisHelper/ResNet_main_float_acc.py b/Athos/Networks/ResNet/AccuracyAnalysisHelper/ResNet_main_float_acc.py index bfc9c8b3..d99180cb 100644 --- a/Athos/Networks/ResNet/AccuracyAnalysisHelper/ResNet_main_float_acc.py +++ b/Athos/Networks/ResNet/AccuracyAnalysisHelper/ResNet_main_float_acc.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy import argparse @@ -28,49 +28,56 @@ import tensorflow as tf import _pickle as pickle -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import ResNet_main batchsize = 1000 N = 50000 -x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_x') -imgnet_model = ResNet_main.ImagenetModel(50, 'channels_last') +x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name="input_x") +imgnet_model = ResNet_main.ImagenetModel(50, "channels_last") pred = imgnet_model(x, False) -finalActivationsFileName = 'floating_point_acc.outp' -argmaxOutputFileName = 'floating_point_argmax.outp' +finalActivationsFileName = "floating_point_acc.outp" +argmaxOutputFileName = "floating_point_argmax.outp" with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) + sess.run(tf.global_variables_initializer()) - modelPath = '../PreTrainedModel/resnet_v2_fp32_savedmodel_NHWC/1538687283/variables/variables' - saver = tf.train.Saver() - saver.restore(sess, modelPath) - - with open(finalActivationsFileName,'w') as ff: - pass - with open(argmaxOutputFileName,'w') as ff: - pass - numbatches = N//batchsize - for batchNum in range(numbatches): - startImgNum = (batchNum*batchsize) + 1 - endImgNum = N if (batchNum == numbatches-1) else (((batchNum+1)*batchsize)) - print("Processing images from start,end = {0}, {1}".format(startImgNum, endImgNum)) - images = numpy.zeros(shape=(endImgNum-startImgNum+1,224,224,3)) - for curImgNum in range(startImgNum, endImgNum+1): - with open('./PreProcessedImages/ImageNum_'+str(curImgNum)+'.inp', 'r') as ff: - line = ff.readline() - images[curImgNum-startImgNum] = numpy.reshape(list(map(lambda x : float(x), line.split())), (224,224,3)) - feed_dict = {x: images} - predictions = sess.run(pred, feed_dict=feed_dict) - with open(finalActivationsFileName, 'a') as ff: - with open(argmaxOutputFileName, 'a') as gg: - for i in range(endImgNum-startImgNum+1): - ff.write('Answer for imgCounter = ' + str(startImgNum+i) + '\n') - for elem in numpy.nditer(predictions[i],order='C'): - ff.write(str(elem)+' ') - ff.write('\n\n') - gg.write('Answer for imgCounter = '+str(startImgNum+i)+' is ') - gg.write(str(numpy.argmax(predictions[i]))+'\n') + modelPath = "../PreTrainedModel/resnet_v2_fp32_savedmodel_NHWC/1538687283/variables/variables" + saver = tf.train.Saver() + saver.restore(sess, modelPath) + with open(finalActivationsFileName, "w") as ff: + pass + with open(argmaxOutputFileName, "w") as ff: + pass + numbatches = N // batchsize + for batchNum in range(numbatches): + startImgNum = (batchNum * batchsize) + 1 + endImgNum = ( + N if (batchNum == numbatches - 1) else (((batchNum + 1) * batchsize)) + ) + print( + "Processing images from start,end = {0}, {1}".format(startImgNum, endImgNum) + ) + images = numpy.zeros(shape=(endImgNum - startImgNum + 1, 224, 224, 3)) + for curImgNum in range(startImgNum, endImgNum + 1): + with open( + "./PreProcessedImages/ImageNum_" + str(curImgNum) + ".inp", "r" + ) as ff: + line = ff.readline() + images[curImgNum - startImgNum] = numpy.reshape( + list(map(lambda x: float(x), line.split())), (224, 224, 3) + ) + feed_dict = {x: images} + predictions = sess.run(pred, feed_dict=feed_dict) + with open(finalActivationsFileName, "a") as ff: + with open(argmaxOutputFileName, "a") as gg: + for i in range(endImgNum - startImgNum + 1): + ff.write("Answer for imgCounter = " + str(startImgNum + i) + "\n") + for elem in numpy.nditer(predictions[i], order="C"): + ff.write(str(elem) + " ") + ff.write("\n\n") + gg.write("Answer for imgCounter = " + str(startImgNum + i) + " is ") + gg.write(str(numpy.argmax(predictions[i])) + "\n") diff --git a/Athos/Networks/ResNet/PreProcessingImages/ResNet_preprocess_main.py b/Athos/Networks/ResNet/PreProcessingImages/ResNet_preprocess_main.py index 50bb1bc2..bf8ab837 100644 --- a/Athos/Networks/ResNet/PreProcessingImages/ResNet_preprocess_main.py +++ b/Athos/Networks/ResNet/PreProcessingImages/ResNet_preprocess_main.py @@ -47,387 +47,442 @@ ################################################# # Creating tf record class ImageCoder(object): - """Helper class that provides TensorFlow image coding utilities.""" + """Helper class that provides TensorFlow image coding utilities.""" - def __init__(self): - # Create a single Session to run all image coding calls. - self._sess = tf.Session() + def __init__(self): + # Create a single Session to run all image coding calls. + self._sess = tf.Session() - # Initializes function that converts PNG to JPEG data. - self._png_data = tf.placeholder(dtype=tf.string) - image = tf.image.decode_png(self._png_data, channels=3) - self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) + # Initializes function that converts PNG to JPEG data. + self._png_data = tf.placeholder(dtype=tf.string) + image = tf.image.decode_png(self._png_data, channels=3) + self._png_to_jpeg = tf.image.encode_jpeg(image, format="rgb", quality=100) - # Initializes function that converts CMYK JPEG data to RGB JPEG data. - self._cmyk_data = tf.placeholder(dtype=tf.string) - image = tf.image.decode_jpeg(self._cmyk_data, channels=0) - self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100) + # Initializes function that converts CMYK JPEG data to RGB JPEG data. + self._cmyk_data = tf.placeholder(dtype=tf.string) + image = tf.image.decode_jpeg(self._cmyk_data, channels=0) + self._cmyk_to_rgb = tf.image.encode_jpeg(image, format="rgb", quality=100) - # Initializes function that decodes RGB JPEG data. - self._decode_jpeg_data = tf.placeholder(dtype=tf.string) - self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) + # Initializes function that decodes RGB JPEG data. + self._decode_jpeg_data = tf.placeholder(dtype=tf.string) + self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) - def png_to_jpeg(self, image_data): - return self._sess.run(self._png_to_jpeg, - feed_dict={self._png_data: image_data}) + def png_to_jpeg(self, image_data): + return self._sess.run(self._png_to_jpeg, feed_dict={self._png_data: image_data}) - def cmyk_to_rgb(self, image_data): - return self._sess.run(self._cmyk_to_rgb, - feed_dict={self._cmyk_data: image_data}) + def cmyk_to_rgb(self, image_data): + return self._sess.run( + self._cmyk_to_rgb, feed_dict={self._cmyk_data: image_data} + ) + + def decode_jpeg(self, image_data): + image = self._sess.run( + self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data} + ) + assert len(image.shape) == 3 + assert image.shape[2] == 3 + return image - def decode_jpeg(self, image_data): - image = self._sess.run(self._decode_jpeg, - feed_dict={self._decode_jpeg_data: image_data}) - assert len(image.shape) == 3 - assert image.shape[2] == 3 - return image def _int64_feature(value): - """Wrapper for inserting int64 features into Example proto.""" - if not isinstance(value, list): - value = [value] - return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) + """Wrapper for inserting int64 features into Example proto.""" + if not isinstance(value, list): + value = [value] + return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) + def _bytes_feature(value): - """Wrapper for inserting bytes features into Example proto.""" - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + """Wrapper for inserting bytes features into Example proto.""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + def _process_image(filename, coder): - """Process a single image file. - - Args: - filename: string, path to an image file e.g., '/path/to/example.JPG'. - coder: instance of ImageCoder to provide TensorFlow image coding utils. - Returns: - image_buffer: string, JPEG encoding of RGB image. - height: integer, image height in pixels. - width: integer, image width in pixels. - """ - # Read the image file. - with tf.gfile.GFile(filename, 'rb') as f: - image_data = f.read() - - # Decode the RGB JPEG. - image = coder.decode_jpeg(image_data) - - # Check that image converted to RGB - assert len(image.shape) == 3 - height = image.shape[0] - width = image.shape[1] - assert image.shape[2] == 3 - - return image_data, height, width + """Process a single image file. + + Args: + filename: string, path to an image file e.g., '/path/to/example.JPG'. + coder: instance of ImageCoder to provide TensorFlow image coding utils. + Returns: + image_buffer: string, JPEG encoding of RGB image. + height: integer, image height in pixels. + width: integer, image width in pixels. + """ + # Read the image file. + with tf.gfile.GFile(filename, "rb") as f: + image_data = f.read() + + # Decode the RGB JPEG. + image = coder.decode_jpeg(image_data) + + # Check that image converted to RGB + assert len(image.shape) == 3 + height = image.shape[0] + width = image.shape[1] + assert image.shape[2] == 3 + + return image_data, height, width + def _convert_to_example(filename, image_buffer, label, synset, height, width): - """Build an Example proto for an example. - - Args: - filename: string, path to an image file, e.g., '/path/to/example.JPG' - image_buffer: string, JPEG encoding of RGB image - label: integer, identifier for the ground truth for the network - synset: string, unique WordNet ID specifying the label, e.g., 'n02323233' - height: integer, image height in pixels - width: integer, image width in pixels - Returns: - Example proto - """ - colorspace = 'RGB' - channels = 3 - image_format = 'JPEG' - - example = tf.train.Example(features=tf.train.Features(feature={ - 'image/height': _int64_feature(height), - 'image/width': _int64_feature(width), - 'image/colorspace': _bytes_feature(colorspace.encode('utf-8')), - 'image/channels': _int64_feature(channels), - 'image/class/label': _int64_feature(label), - 'image/class/synset': _bytes_feature(synset.encode('utf-8')), - 'image/format': _bytes_feature(image_format.encode('utf-8')), - 'image/filename': _bytes_feature(os.path.basename(filename).encode('utf-8')), - 'image/encoded': _bytes_feature(image_buffer)})) - return example + """Build an Example proto for an example. + + Args: + filename: string, path to an image file, e.g., '/path/to/example.JPG' + image_buffer: string, JPEG encoding of RGB image + label: integer, identifier for the ground truth for the network + synset: string, unique WordNet ID specifying the label, e.g., 'n02323233' + height: integer, image height in pixels + width: integer, image width in pixels + Returns: + Example proto + """ + colorspace = "RGB" + channels = 3 + image_format = "JPEG" + + example = tf.train.Example( + features=tf.train.Features( + feature={ + "image/height": _int64_feature(height), + "image/width": _int64_feature(width), + "image/colorspace": _bytes_feature(colorspace.encode("utf-8")), + "image/channels": _int64_feature(channels), + "image/class/label": _int64_feature(label), + "image/class/synset": _bytes_feature(synset.encode("utf-8")), + "image/format": _bytes_feature(image_format.encode("utf-8")), + "image/filename": _bytes_feature( + os.path.basename(filename).encode("utf-8") + ), + "image/encoded": _bytes_feature(image_buffer), + } + ) + ) + return example + + ################################################# ################################################# # Parsing tf record def _parse_example_proto(example_serialized): - """Parses an Example proto containing a training example of an image. - - The output of the build_image_data.py image preprocessing script is a dataset - containing serialized Example protocol buffers. Each Example proto contains - the following fields (values are included as examples): - - image/height: 462 - image/width: 581 - image/colorspace: 'RGB' - image/channels: 3 - image/class/label: 615 - image/class/synset: 'n03623198' - image/class/text: 'knee pad' - image/object/bbox/xmin: 0.1 - image/object/bbox/xmax: 0.9 - image/object/bbox/ymin: 0.2 - image/object/bbox/ymax: 0.6 - image/object/bbox/label: 615 - image/format: 'JPEG' - image/filename: 'ILSVRC2012_val_00041207.JPEG' - image/encoded: - - Args: - example_serialized: scalar Tensor tf.string containing a serialized - Example protocol buffer. - - Returns: - image_buffer: Tensor tf.string containing the contents of a JPEG file. - label: Tensor tf.int32 containing the label. - bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] - where each coordinate is [0, 1) and the coordinates are arranged as - [ymin, xmin, ymax, xmax]. - """ - # Dense features in Example proto. - feature_map = { - 'image/encoded': tf.io.FixedLenFeature([], dtype=tf.string, - default_value=''), - 'image/class/label': tf.io.FixedLenFeature([], dtype=tf.int64, - default_value=-1), - 'image/class/text': tf.io.FixedLenFeature([], dtype=tf.string, - default_value=''), - } - sparse_float32 = tf.io.VarLenFeature(dtype=tf.float32) - # Sparse features in Example proto. - feature_map.update( - {k: sparse_float32 for k in ['image/object/bbox/xmin', - 'image/object/bbox/ymin', - 'image/object/bbox/xmax', - 'image/object/bbox/ymax']}) - - features = tf.io.parse_single_example(serialized=example_serialized, - features=feature_map) - label = tf.cast(features['image/class/label'], dtype=tf.int32) - - xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0) - ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0) - xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0) - ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0) - - # Note that we impose an ordering of (y, x) just to make life difficult. - bbox = tf.concat([ymin, xmin, ymax, xmax], 0) - - # Force the variable number of bounding boxes into the shape - # [1, num_boxes, coords]. - bbox = tf.expand_dims(bbox, 0) - bbox = tf.transpose(a=bbox, perm=[0, 2, 1]) - - return features['image/encoded'], label, bbox + """Parses an Example proto containing a training example of an image. + + The output of the build_image_data.py image preprocessing script is a dataset + containing serialized Example protocol buffers. Each Example proto contains + the following fields (values are included as examples): + + image/height: 462 + image/width: 581 + image/colorspace: 'RGB' + image/channels: 3 + image/class/label: 615 + image/class/synset: 'n03623198' + image/class/text: 'knee pad' + image/object/bbox/xmin: 0.1 + image/object/bbox/xmax: 0.9 + image/object/bbox/ymin: 0.2 + image/object/bbox/ymax: 0.6 + image/object/bbox/label: 615 + image/format: 'JPEG' + image/filename: 'ILSVRC2012_val_00041207.JPEG' + image/encoded: + + Args: + example_serialized: scalar Tensor tf.string containing a serialized + Example protocol buffer. + + Returns: + image_buffer: Tensor tf.string containing the contents of a JPEG file. + label: Tensor tf.int32 containing the label. + bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] + where each coordinate is [0, 1) and the coordinates are arranged as + [ymin, xmin, ymax, xmax]. + """ + # Dense features in Example proto. + feature_map = { + "image/encoded": tf.io.FixedLenFeature([], dtype=tf.string, default_value=""), + "image/class/label": tf.io.FixedLenFeature( + [], dtype=tf.int64, default_value=-1 + ), + "image/class/text": tf.io.FixedLenFeature( + [], dtype=tf.string, default_value="" + ), + } + sparse_float32 = tf.io.VarLenFeature(dtype=tf.float32) + # Sparse features in Example proto. + feature_map.update( + { + k: sparse_float32 + for k in [ + "image/object/bbox/xmin", + "image/object/bbox/ymin", + "image/object/bbox/xmax", + "image/object/bbox/ymax", + ] + } + ) + + features = tf.io.parse_single_example( + serialized=example_serialized, features=feature_map + ) + label = tf.cast(features["image/class/label"], dtype=tf.int32) + + xmin = tf.expand_dims(features["image/object/bbox/xmin"].values, 0) + ymin = tf.expand_dims(features["image/object/bbox/ymin"].values, 0) + xmax = tf.expand_dims(features["image/object/bbox/xmax"].values, 0) + ymax = tf.expand_dims(features["image/object/bbox/ymax"].values, 0) + + # Note that we impose an ordering of (y, x) just to make life difficult. + bbox = tf.concat([ymin, xmin, ymax, xmax], 0) + + # Force the variable number of bounding boxes into the shape + # [1, num_boxes, coords]. + bbox = tf.expand_dims(bbox, 0) + bbox = tf.transpose(a=bbox, perm=[0, 2, 1]) + + return features["image/encoded"], label, bbox + def parse_record(raw_record, is_training, dtype): - """Parses a record containing a training example of an image. - - The input record is parsed into a label and image, and the image is passed - through preprocessing steps (cropping, flipping, and so on). - - Args: - raw_record: scalar Tensor tf.string containing a serialized - Example protocol buffer. - is_training: A boolean denoting whether the input is for training. - dtype: data type to use for images/features. - - Returns: - Tuple with processed image tensor and one-hot-encoded label tensor. - """ - image_buffer, label, bbox = _parse_example_proto(raw_record) - - image = imagenet_preprocessing.preprocess_image( - image_buffer=image_buffer, - bbox=bbox, - output_height=DEFAULT_IMAGE_SIZE, - output_width=DEFAULT_IMAGE_SIZE, - num_channels=NUM_CHANNELS, - is_training=is_training) - image = tf.cast(image, dtype) - - return image, label + """Parses a record containing a training example of an image. + + The input record is parsed into a label and image, and the image is passed + through preprocessing steps (cropping, flipping, and so on). + + Args: + raw_record: scalar Tensor tf.string containing a serialized + Example protocol buffer. + is_training: A boolean denoting whether the input is for training. + dtype: data type to use for images/features. + + Returns: + Tuple with processed image tensor and one-hot-encoded label tensor. + """ + image_buffer, label, bbox = _parse_example_proto(raw_record) + + image = imagenet_preprocessing.preprocess_image( + image_buffer=image_buffer, + bbox=bbox, + output_height=DEFAULT_IMAGE_SIZE, + output_width=DEFAULT_IMAGE_SIZE, + num_channels=NUM_CHANNELS, + is_training=is_training, + ) + image = tf.cast(image, dtype) + + return image, label + + ################################################# ################################################# # Parsing image to get bounding box info class BoundingBox(object): - pass + pass + def GetItem(name, root, index=0): - count = 0 - for item in root.iter(name): - if count == index: - return item.text - count += 1 - # Failed to find "index" occurrence of item. - return -1 + count = 0 + for item in root.iter(name): + if count == index: + return item.text + count += 1 + # Failed to find "index" occurrence of item. + return -1 + def GetInt(name, root, index=0): - return int(GetItem(name, root, index)) + return int(GetItem(name, root, index)) + def FindNumberBoundingBoxes(root): - index = 0 - while True: - if GetInt('xmin', root, index) == -1: - break - index += 1 - return index + index = 0 + while True: + if GetInt("xmin", root, index) == -1: + break + index += 1 + return index + def ProcessXMLAnnotation(xml_file): - """Process a single XML file containing a bounding box.""" - # pylint: disable=broad-except - try: - tree = ET.parse(xml_file) - except Exception: - print('Failed to parse: ' + xml_file, file=sys.stderr) - return None - # pylint: enable=broad-except - root = tree.getroot() - - num_boxes = FindNumberBoundingBoxes(root) - boxes = [] - - for index in range(num_boxes): - box = BoundingBox() - # Grab the 'index' annotation. - box.xmin = GetInt('xmin', root, index) - box.ymin = GetInt('ymin', root, index) - box.xmax = GetInt('xmax', root, index) - box.ymax = GetInt('ymax', root, index) - - box.width = GetInt('width', root) - box.height = GetInt('height', root) - box.filename = GetItem('filename', root) + '.JPEG' - box.label = GetItem('name', root) - - xmin = float(box.xmin) / float(box.width) - xmax = float(box.xmax) / float(box.width) - ymin = float(box.ymin) / float(box.height) - ymax = float(box.ymax) / float(box.height) - - # Some images contain bounding box annotations that - # extend outside of the supplied image. See, e.g. - # n03127925/n03127925_147.xml - # Additionally, for some bounding boxes, the min > max - # or the box is entirely outside of the image. - min_x = min(xmin, xmax) - max_x = max(xmin, xmax) - box.xmin_scaled = min(max(min_x, 0.0), 1.0) - box.xmax_scaled = min(max(max_x, 0.0), 1.0) - - min_y = min(ymin, ymax) - max_y = max(ymin, ymax) - box.ymin_scaled = min(max(min_y, 0.0), 1.0) - box.ymax_scaled = min(max(max_y, 0.0), 1.0) - - boxes.append(box) - - return boxes + """Process a single XML file containing a bounding box.""" + # pylint: disable=broad-except + try: + tree = ET.parse(xml_file) + except Exception: + print("Failed to parse: " + xml_file, file=sys.stderr) + return None + # pylint: enable=broad-except + root = tree.getroot() + + num_boxes = FindNumberBoundingBoxes(root) + boxes = [] + + for index in range(num_boxes): + box = BoundingBox() + # Grab the 'index' annotation. + box.xmin = GetInt("xmin", root, index) + box.ymin = GetInt("ymin", root, index) + box.xmax = GetInt("xmax", root, index) + box.ymax = GetInt("ymax", root, index) + + box.width = GetInt("width", root) + box.height = GetInt("height", root) + box.filename = GetItem("filename", root) + ".JPEG" + box.label = GetItem("name", root) + + xmin = float(box.xmin) / float(box.width) + xmax = float(box.xmax) / float(box.width) + ymin = float(box.ymin) / float(box.height) + ymax = float(box.ymax) / float(box.height) + + # Some images contain bounding box annotations that + # extend outside of the supplied image. See, e.g. + # n03127925/n03127925_147.xml + # Additionally, for some bounding boxes, the min > max + # or the box is entirely outside of the image. + min_x = min(xmin, xmax) + max_x = max(xmin, xmax) + box.xmin_scaled = min(max(min_x, 0.0), 1.0) + box.xmax_scaled = min(max(max_x, 0.0), 1.0) + + min_y = min(ymin, ymax) + max_y = max(ymin, ymax) + box.ymin_scaled = min(max(min_y, 0.0), 1.0) + box.ymax_scaled = min(max(max_y, 0.0), 1.0) + + boxes.append(box) + + return boxes + ################################################# # Old function def CreateTFRecordFromImage(input_filename, output_filename): - coder = ImageCoder() - - writer = tf.python_io.TFRecordWriter(output_filename) - image_buffer, height, width = _process_image(input_filename, coder) - label = 0 - synset = 'n02323233' - example = _convert_to_example(input_filename, image_buffer, label, - synset, height, width) - writer.write(example.SerializeToString()) - writer.close() - -# Old function -def ReadAndPreProcessTFRecord(tfRecordFileName): - reader = tf.TFRecordReader() - filename_queue = tf.train.string_input_producer([tfRecordFileName]) - _, serialized_example = reader.read(filename_queue) - - parsed_image_tensor = parse_record(serialized_example, False, tf.float32) - return parsed_image_tensor - -def get_sample(input_img_filename, input_xml_filename, coder = None): - if coder is None: coder = ImageCoder() - image_buffer, height, width = _process_image(input_img_filename, coder) - boxes = ProcessXMLAnnotation(input_xml_filename) - # print(boxes[0].xmin, boxes[0].ymin, boxes[0].xmax, boxes[0].ymax) - # print(boxes[0].xmin_scaled, boxes[0].ymin_scaled, boxes[0].xmax_scaled, boxes[0].ymax_scaled) + writer = tf.python_io.TFRecordWriter(output_filename) + image_buffer, height, width = _process_image(input_filename, coder) + label = 0 + synset = "n02323233" + example = _convert_to_example( + input_filename, image_buffer, label, synset, height, width + ) + writer.write(example.SerializeToString()) + writer.close() - xmin = tf.expand_dims(boxes[0].xmin_scaled, 0) - ymin = tf.expand_dims(boxes[0].ymin_scaled, 0) - xmax = tf.expand_dims(boxes[0].xmax_scaled, 0) - ymax = tf.expand_dims(boxes[0].ymax_scaled, 0) - # Note that we impose an ordering of (y, x) just to make life difficult. - bbox = tf.concat([ymin, xmin, ymax, xmax], 0) - - # Force the variable number of bounding boxes into the shape - # [1, num_boxes, coords]. - bbox = tf.expand_dims(bbox, 0) - bbox = tf.expand_dims(bbox, 0) - # bbox = tf.transpose(a=bbox, perm=[0, 2, 1]) +# Old function +def ReadAndPreProcessTFRecord(tfRecordFileName): + reader = tf.TFRecordReader() + filename_queue = tf.train.string_input_producer([tfRecordFileName]) + _, serialized_example = reader.read(filename_queue) + + parsed_image_tensor = parse_record(serialized_example, False, tf.float32) + return parsed_image_tensor + + +def get_sample(input_img_filename, input_xml_filename, coder=None): + if coder is None: + coder = ImageCoder() + image_buffer, height, width = _process_image(input_img_filename, coder) + + boxes = ProcessXMLAnnotation(input_xml_filename) + # print(boxes[0].xmin, boxes[0].ymin, boxes[0].xmax, boxes[0].ymax) + # print(boxes[0].xmin_scaled, boxes[0].ymin_scaled, boxes[0].xmax_scaled, boxes[0].ymax_scaled) + + xmin = tf.expand_dims(boxes[0].xmin_scaled, 0) + ymin = tf.expand_dims(boxes[0].ymin_scaled, 0) + xmax = tf.expand_dims(boxes[0].xmax_scaled, 0) + ymax = tf.expand_dims(boxes[0].ymax_scaled, 0) + + # Note that we impose an ordering of (y, x) just to make life difficult. + bbox = tf.concat([ymin, xmin, ymax, xmax], 0) + + # Force the variable number of bounding boxes into the shape + # [1, num_boxes, coords]. + bbox = tf.expand_dims(bbox, 0) + bbox = tf.expand_dims(bbox, 0) + # bbox = tf.transpose(a=bbox, perm=[0, 2, 1]) + + image = imagenet_preprocessing.preprocess_image( + image_buffer=image_buffer, + bbox=bbox, + output_height=DEFAULT_IMAGE_SIZE, + output_width=DEFAULT_IMAGE_SIZE, + num_channels=NUM_CHANNELS, + is_training=False, + ) - image = imagenet_preprocessing.preprocess_image( - image_buffer=image_buffer, - bbox=bbox, - output_height=DEFAULT_IMAGE_SIZE, - output_width=DEFAULT_IMAGE_SIZE, - num_channels=NUM_CHANNELS, - is_training=False) + return image - return image def dumpImageDataFloat(imgData, filename, writeMode): - with open(filename, writeMode) as ff: - for xx in numpy.nditer(imgData, order='C'): - ff.write(str(xx) + ' ') - ff.write('\n\n') + with open(filename, writeMode) as ff: + for xx in numpy.nditer(imgData, order="C"): + ff.write(str(xx) + " ") + ff.write("\n\n") + def main(): - if not((len(sys.argv) >= 7) and (len(sys.argv) <= 8)): - print("Args : ?", file=sys.stderr) - exit(1) - - imgFolderName = sys.argv[1] - bboxFolderName = sys.argv[2] - fileNamePrefix = sys.argv[3] - preProcessedImgFolderName = sys.argv[4] - firstImgNum = int(sys.argv[5]) - lastImgNum = int(sys.argv[6]) - randomSubsetIdxFile = None - if (len(sys.argv) == 8): - randomSubsetIdxFile = sys.argv[7] - - randomIdxToBeChosen = None - if randomSubsetIdxFile: - with open(randomSubsetIdxFile, 'r') as ff: - randomIdxToBeChosen = ff.readlines() - randomIdxToBeChosen = list(map(lambda x : int(x.rstrip()), randomIdxToBeChosen)) - assert(lastImgNum <= len(randomIdxToBeChosen)+1) #Assert that the last img num passed is within bounds - - sess = tf.Session() - sess.run(tf.global_variables_initializer()) - coder = ImageCoder() - for curImgNum in range(firstImgNum, lastImgNum): - if (curImgNum % 100 == 0): - print("CurImgNum = ", curImgNum) - actualImgNum = curImgNum if not(randomIdxToBeChosen) else randomIdxToBeChosen[curImgNum-1] - imgFileName = os.path.join(imgFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + '.JPEG') - bboxFileName = os.path.join(bboxFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + '.xml') - - image = get_sample(imgFileName, bboxFileName, coder) - outp = sess.run([image], feed_dict={}) - - saveFilePath = os.path.join(preProcessedImgFolderName, 'ImageNum_' + str(actualImgNum) + '.inp') - dumpImageDataFloat(outp, saveFilePath, 'w') - -if __name__=='__main__': - main() + if not ((len(sys.argv) >= 7) and (len(sys.argv) <= 8)): + print( + "Args : ?", + file=sys.stderr, + ) + exit(1) + + imgFolderName = sys.argv[1] + bboxFolderName = sys.argv[2] + fileNamePrefix = sys.argv[3] + preProcessedImgFolderName = sys.argv[4] + firstImgNum = int(sys.argv[5]) + lastImgNum = int(sys.argv[6]) + randomSubsetIdxFile = None + if len(sys.argv) == 8: + randomSubsetIdxFile = sys.argv[7] + + randomIdxToBeChosen = None + if randomSubsetIdxFile: + with open(randomSubsetIdxFile, "r") as ff: + randomIdxToBeChosen = ff.readlines() + randomIdxToBeChosen = list( + map(lambda x: int(x.rstrip()), randomIdxToBeChosen) + ) + assert ( + lastImgNum <= len(randomIdxToBeChosen) + 1 + ) # Assert that the last img num passed is within bounds + + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + coder = ImageCoder() + for curImgNum in range(firstImgNum, lastImgNum): + if curImgNum % 100 == 0: + print("CurImgNum = ", curImgNum) + actualImgNum = ( + curImgNum + if not (randomIdxToBeChosen) + else randomIdxToBeChosen[curImgNum - 1] + ) + imgFileName = os.path.join( + imgFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + ".JPEG" + ) + bboxFileName = os.path.join( + bboxFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + ".xml" + ) + + image = get_sample(imgFileName, bboxFileName, coder) + outp = sess.run([image], feed_dict={}) + + saveFilePath = os.path.join( + preProcessedImgFolderName, "ImageNum_" + str(actualImgNum) + ".inp" + ) + dumpImageDataFloat(outp, saveFilePath, "w") + + +if __name__ == "__main__": + main() diff --git a/Athos/Networks/ResNet/PreProcessingImages/imagenet_preprocessing.py b/Athos/Networks/ResNet/PreProcessingImages/imagenet_preprocessing.py index 0caebe58..934622a2 100644 --- a/Athos/Networks/ResNet/PreProcessingImages/imagenet_preprocessing.py +++ b/Athos/Networks/ResNet/PreProcessingImages/imagenet_preprocessing.py @@ -49,212 +49,217 @@ def _decode_crop_and_flip(image_buffer, bbox, num_channels): - """Crops the given image to a random part of the image, and randomly flips. - - We use the fused decode_and_crop op, which performs better than the two ops - used separately in series, but note that this requires that the image be - passed in as an un-decoded string Tensor. - - Args: - image_buffer: scalar string Tensor representing the raw JPEG image buffer. - bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] - where each coordinate is [0, 1) and the coordinates are arranged as - [ymin, xmin, ymax, xmax]. - num_channels: Integer depth of the image buffer for decoding. - - Returns: - 3-D tensor with cropped image. - - """ - # A large fraction of image datasets contain a human-annotated bounding box - # delineating the region of the image containing the object of interest. We - # choose to create a new bounding box for the object which is a randomly - # distorted version of the human-annotated bounding box that obeys an - # allowed range of aspect ratios, sizes and overlap with the human-annotated - # bounding box. If no box is supplied, then we assume the bounding box is - # the entire image. - sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( - tf.image.extract_jpeg_shape(image_buffer), - bounding_boxes=bbox, - min_object_covered=0.1, - aspect_ratio_range=[0.75, 1.33], - area_range=[0.05, 1.0], - max_attempts=100, - use_image_if_no_bounding_boxes=True) - bbox_begin, bbox_size, _ = sample_distorted_bounding_box - - # Reassemble the bounding box in the format the crop op requires. - offset_y, offset_x, _ = tf.unstack(bbox_begin) - target_height, target_width, _ = tf.unstack(bbox_size) - crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) - - # Use the fused decode and crop op here, which is faster than each in series. - cropped = tf.image.decode_and_crop_jpeg( - image_buffer, crop_window, channels=num_channels) - - # Flip to add a little more random distortion in. - cropped = tf.image.random_flip_left_right(cropped) - return cropped + """Crops the given image to a random part of the image, and randomly flips. + + We use the fused decode_and_crop op, which performs better than the two ops + used separately in series, but note that this requires that the image be + passed in as an un-decoded string Tensor. + + Args: + image_buffer: scalar string Tensor representing the raw JPEG image buffer. + bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] + where each coordinate is [0, 1) and the coordinates are arranged as + [ymin, xmin, ymax, xmax]. + num_channels: Integer depth of the image buffer for decoding. + + Returns: + 3-D tensor with cropped image. + + """ + # A large fraction of image datasets contain a human-annotated bounding box + # delineating the region of the image containing the object of interest. We + # choose to create a new bounding box for the object which is a randomly + # distorted version of the human-annotated bounding box that obeys an + # allowed range of aspect ratios, sizes and overlap with the human-annotated + # bounding box. If no box is supplied, then we assume the bounding box is + # the entire image. + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + tf.image.extract_jpeg_shape(image_buffer), + bounding_boxes=bbox, + min_object_covered=0.1, + aspect_ratio_range=[0.75, 1.33], + area_range=[0.05, 1.0], + max_attempts=100, + use_image_if_no_bounding_boxes=True, + ) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Reassemble the bounding box in the format the crop op requires. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + + # Use the fused decode and crop op here, which is faster than each in series. + cropped = tf.image.decode_and_crop_jpeg( + image_buffer, crop_window, channels=num_channels + ) + + # Flip to add a little more random distortion in. + cropped = tf.image.random_flip_left_right(cropped) + return cropped def _central_crop(image, crop_height, crop_width): - """Performs central crops of the given image list. + """Performs central crops of the given image list. - Args: - image: a 3-D image tensor - crop_height: the height of the image following the crop. - crop_width: the width of the image following the crop. + Args: + image: a 3-D image tensor + crop_height: the height of the image following the crop. + crop_width: the width of the image following the crop. - Returns: - 3-D tensor with cropped image. - """ - shape = tf.shape(input=image) - height, width = shape[0], shape[1] + Returns: + 3-D tensor with cropped image. + """ + shape = tf.shape(input=image) + height, width = shape[0], shape[1] - amount_to_be_cropped_h = (height - crop_height) - crop_top = amount_to_be_cropped_h // 2 - amount_to_be_cropped_w = (width - crop_width) - crop_left = amount_to_be_cropped_w // 2 - return tf.slice( - image, [crop_top, crop_left, 0], [crop_height, crop_width, -1]) + amount_to_be_cropped_h = height - crop_height + crop_top = amount_to_be_cropped_h // 2 + amount_to_be_cropped_w = width - crop_width + crop_left = amount_to_be_cropped_w // 2 + return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1]) def _mean_image_subtraction(image, means, num_channels): - """Subtracts the given means from each image channel. + """Subtracts the given means from each image channel. - For example: - means = [123.68, 116.779, 103.939] - image = _mean_image_subtraction(image, means) + For example: + means = [123.68, 116.779, 103.939] + image = _mean_image_subtraction(image, means) - Note that the rank of `image` must be known. + Note that the rank of `image` must be known. - Args: - image: a tensor of size [height, width, C]. - means: a C-vector of values to subtract from each channel. - num_channels: number of color channels in the image that will be distorted. + Args: + image: a tensor of size [height, width, C]. + means: a C-vector of values to subtract from each channel. + num_channels: number of color channels in the image that will be distorted. - Returns: - the centered image. + Returns: + the centered image. - Raises: - ValueError: If the rank of `image` is unknown, if `image` has a rank other - than three or if the number of channels in `image` doesn't match the - number of values in `means`. - """ - if image.get_shape().ndims != 3: - raise ValueError('Input must be of size [height, width, C>0]') + Raises: + ValueError: If the rank of `image` is unknown, if `image` has a rank other + than three or if the number of channels in `image` doesn't match the + number of values in `means`. + """ + if image.get_shape().ndims != 3: + raise ValueError("Input must be of size [height, width, C>0]") - if len(means) != num_channels: - raise ValueError('len(means) must match the number of channels') + if len(means) != num_channels: + raise ValueError("len(means) must match the number of channels") - # We have a 1-D tensor of means; convert to 3-D. - means = tf.expand_dims(tf.expand_dims(means, 0), 0) + # We have a 1-D tensor of means; convert to 3-D. + means = tf.expand_dims(tf.expand_dims(means, 0), 0) - return image - means + return image - means def _smallest_size_at_least(height, width, resize_min): - """Computes new shape with the smallest side equal to `smallest_side`. + """Computes new shape with the smallest side equal to `smallest_side`. - Computes new shape with the smallest side equal to `smallest_side` while - preserving the original aspect ratio. + Computes new shape with the smallest side equal to `smallest_side` while + preserving the original aspect ratio. - Args: - height: an int32 scalar tensor indicating the current height. - width: an int32 scalar tensor indicating the current width. - resize_min: A python integer or scalar `Tensor` indicating the size of - the smallest side after resize. + Args: + height: an int32 scalar tensor indicating the current height. + width: an int32 scalar tensor indicating the current width. + resize_min: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. - Returns: - new_height: an int32 scalar tensor indicating the new height. - new_width: an int32 scalar tensor indicating the new width. - """ - resize_min = tf.cast(resize_min, tf.float32) + Returns: + new_height: an int32 scalar tensor indicating the new height. + new_width: an int32 scalar tensor indicating the new width. + """ + resize_min = tf.cast(resize_min, tf.float32) - # Convert to floats to make subsequent calculations go smoothly. - height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) + # Convert to floats to make subsequent calculations go smoothly. + height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) - smaller_dim = tf.minimum(height, width) - scale_ratio = resize_min / smaller_dim + smaller_dim = tf.minimum(height, width) + scale_ratio = resize_min / smaller_dim - # Convert back to ints to make heights and widths that TF ops will accept. - new_height = tf.cast(height * scale_ratio, tf.int32) - new_width = tf.cast(width * scale_ratio, tf.int32) + # Convert back to ints to make heights and widths that TF ops will accept. + new_height = tf.cast(height * scale_ratio, tf.int32) + new_width = tf.cast(width * scale_ratio, tf.int32) - return new_height, new_width + return new_height, new_width def _aspect_preserving_resize(image, resize_min): - """Resize images preserving the original aspect ratio. + """Resize images preserving the original aspect ratio. - Args: - image: A 3-D image `Tensor`. - resize_min: A python integer or scalar `Tensor` indicating the size of - the smallest side after resize. + Args: + image: A 3-D image `Tensor`. + resize_min: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. - Returns: - resized_image: A 3-D tensor containing the resized image. - """ - shape = tf.shape(input=image) - height, width = shape[0], shape[1] + Returns: + resized_image: A 3-D tensor containing the resized image. + """ + shape = tf.shape(input=image) + height, width = shape[0], shape[1] - new_height, new_width = _smallest_size_at_least(height, width, resize_min) + new_height, new_width = _smallest_size_at_least(height, width, resize_min) - return _resize_image(image, new_height, new_width) + return _resize_image(image, new_height, new_width) def _resize_image(image, height, width): - """Simple wrapper around tf.resize_images. - - This is primarily to make sure we use the same `ResizeMethod` and other - details each time. - - Args: - image: A 3-D image `Tensor`. - height: The target height for the resized image. - width: The target width for the resized image. - - Returns: - resized_image: A 3-D tensor containing the resized image. The first two - dimensions have the shape [height, width]. - """ - return tf.image.resize_images( - image, [height, width], method=tf.image.ResizeMethod.BILINEAR, - align_corners=False) - - -def preprocess_image(image_buffer, bbox, output_height, output_width, - num_channels, is_training=False): - """Preprocesses the given image. - - Preprocessing includes decoding, cropping, and resizing for both training - and eval images. Training preprocessing, however, introduces some random - distortion of the image to improve accuracy. - - Args: - image_buffer: scalar string Tensor representing the raw JPEG image buffer. - bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] - where each coordinate is [0, 1) and the coordinates are arranged as - [ymin, xmin, ymax, xmax]. - output_height: The height of the image after preprocessing. - output_width: The width of the image after preprocessing. - num_channels: Integer depth of the image buffer for decoding. - is_training: `True` if we're preprocessing the image for training and - `False` otherwise. - - Returns: - A preprocessed image. - """ - if is_training: - # For training, we want to randomize some of the distortions. - image = _decode_crop_and_flip(image_buffer, bbox, num_channels) - image = _resize_image(image, output_height, output_width) - else: - # For validation, we want to decode, resize, then just crop the middle. - image = tf.image.decode_jpeg(image_buffer, channels=num_channels) - image = _aspect_preserving_resize(image, _RESIZE_MIN) - image = _central_crop(image, output_height, output_width) - - image.set_shape([output_height, output_width, num_channels]) - - return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels) + """Simple wrapper around tf.resize_images. + + This is primarily to make sure we use the same `ResizeMethod` and other + details each time. + + Args: + image: A 3-D image `Tensor`. + height: The target height for the resized image. + width: The target width for the resized image. + + Returns: + resized_image: A 3-D tensor containing the resized image. The first two + dimensions have the shape [height, width]. + """ + return tf.image.resize_images( + image, + [height, width], + method=tf.image.ResizeMethod.BILINEAR, + align_corners=False, + ) + + +def preprocess_image( + image_buffer, bbox, output_height, output_width, num_channels, is_training=False +): + """Preprocesses the given image. + + Preprocessing includes decoding, cropping, and resizing for both training + and eval images. Training preprocessing, however, introduces some random + distortion of the image to improve accuracy. + + Args: + image_buffer: scalar string Tensor representing the raw JPEG image buffer. + bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] + where each coordinate is [0, 1) and the coordinates are arranged as + [ymin, xmin, ymax, xmax]. + output_height: The height of the image after preprocessing. + output_width: The width of the image after preprocessing. + num_channels: Integer depth of the image buffer for decoding. + is_training: `True` if we're preprocessing the image for training and + `False` otherwise. + + Returns: + A preprocessed image. + """ + if is_training: + # For training, we want to randomize some of the distortions. + image = _decode_crop_and_flip(image_buffer, bbox, num_channels) + image = _resize_image(image, output_height, output_width) + else: + # For validation, we want to decode, resize, then just crop the middle. + image = tf.image.decode_jpeg(image_buffer, channels=num_channels) + image = _aspect_preserving_resize(image, _RESIZE_MIN) + image = _central_crop(image, output_height, output_width) + + image.set_shape([output_height, output_width, num_channels]) + + return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels) diff --git a/Athos/Networks/ResNet/ResNet_main.py b/Athos/Networks/ResNet/ResNet_main.py index cd1d9d6b..3a8dd938 100644 --- a/Athos/Networks/ResNet/ResNet_main.py +++ b/Athos/Networks/ResNet/ResNet_main.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -23,7 +23,7 @@ Parts of the code in this file was taken from the original model from https://github.com/tensorflow/models/tree/master/official/r1/resnet. -''' +""" import os, sys import time @@ -36,6 +36,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from tensorflow.python.util import deprecation + deprecation._PRINT_DEPRECATION_WARNINGS = False try: from tensorflow.python.util import module_wrapper as deprecation @@ -43,7 +44,7 @@ from tensorflow.python.util import deprecation_wrapper as deprecation deprecation._PER_MODULE_WARNING_LIMIT = 0 -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData NUM_CLASSES = 1001 @@ -51,169 +52,234 @@ ############################################## # Model related functions -def _get_block_sizes(resnet_size): - """Retrieve the size of each block_layer in the ResNet model. - - The number of block layers used for the Resnet model varies according - to the size of the model. This helper grabs the layer set we want, throwing - an error if a non-standard size has been selected. - - Args: - resnet_size: The number of convolutional layers needed in the model. - - Returns: - A list of block sizes to use in building the model. - - Raises: - KeyError: if invalid resnet_size is received. - """ - choices = { - 1: [0, 0, 0, 0], - 18: [2, 2, 2, 2], - 34: [3, 4, 6, 3], - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3], - 152: [3, 8, 36, 3], - 200: [3, 24, 36, 3] - } - - try: - return choices[resnet_size] - except KeyError: - err = ('Could not find layers for selected Resnet size.\n' - 'Size received: {}; sizes allowed: {}.'.format( - resnet_size, choices.keys())) - raise ValueError(err) -class ImagenetModel(Resnet_Model.Model): - """Model class with appropriate defaults for Imagenet data.""" +def _get_block_sizes(resnet_size): + """Retrieve the size of each block_layer in the ResNet model. - def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES, - resnet_version=Resnet_Model.DEFAULT_VERSION, - dtype=Resnet_Model.DEFAULT_DTYPE): - """These are the parameters that work for Imagenet data. + The number of block layers used for the Resnet model varies according + to the size of the model. This helper grabs the layer set we want, throwing + an error if a non-standard size has been selected. Args: resnet_size: The number of convolutional layers needed in the model. - data_format: Either 'channels_first' or 'channels_last', specifying which - data format to use when setting up the model. - num_classes: The number of output classes needed from the model. This - enables users to extend the same model to their own datasets. - resnet_version: Integer representing which version of the ResNet network - to use. See README for details. Valid values: [1, 2] - dtype: The TensorFlow dtype to use for calculations. + + Returns: + A list of block sizes to use in building the model. + + Raises: + KeyError: if invalid resnet_size is received. """ + choices = { + 1: [0, 0, 0, 0], + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + 200: [3, 24, 36, 3], + } + + try: + return choices[resnet_size] + except KeyError: + err = ( + "Could not find layers for selected Resnet size.\n" + "Size received: {}; sizes allowed: {}.".format(resnet_size, choices.keys()) + ) + raise ValueError(err) + + +class ImagenetModel(Resnet_Model.Model): + """Model class with appropriate defaults for Imagenet data.""" + + def __init__( + self, + resnet_size, + data_format=None, + num_classes=NUM_CLASSES, + resnet_version=Resnet_Model.DEFAULT_VERSION, + dtype=Resnet_Model.DEFAULT_DTYPE, + ): + """These are the parameters that work for Imagenet data. + + Args: + resnet_size: The number of convolutional layers needed in the model. + data_format: Either 'channels_first' or 'channels_last', specifying which + data format to use when setting up the model. + num_classes: The number of output classes needed from the model. This + enables users to extend the same model to their own datasets. + resnet_version: Integer representing which version of the ResNet network + to use. See README for details. Valid values: [1, 2] + dtype: The TensorFlow dtype to use for calculations. + """ + + # For bigger models, we want to use "bottleneck" layers + if resnet_size < 50: + bottleneck = False + else: + bottleneck = True + + super(ImagenetModel, self).__init__( + resnet_size=resnet_size, + bottleneck=bottleneck, + num_classes=num_classes, + num_filters=64, + kernel_size=7, + conv_stride=2, + first_pool_size=3, + first_pool_stride=2, + block_sizes=_get_block_sizes(resnet_size), + block_strides=[1, 2, 2, 2], + resnet_version=resnet_version, + data_format=data_format, + dtype=dtype, + ) - # For bigger models, we want to use "bottleneck" layers - if resnet_size < 50: - bottleneck = False - else: - bottleneck = True - - super(ImagenetModel, self).__init__( - resnet_size=resnet_size, - bottleneck=bottleneck, - num_classes=num_classes, - num_filters=64, - kernel_size=7, - conv_stride=2, - first_pool_size=3, - first_pool_stride=2, - block_sizes=_get_block_sizes(resnet_size), - block_strides=[1, 2, 2, 2], - resnet_version=resnet_version, - data_format=data_format, - dtype=dtype - ) ############################################## -def infer(savePreTrainedWeightsInt, savePreTrainedWeightsFloat, scalingFac, runPrediction, saveImgAndWtData): - x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_x') - # y = tf.placeholder(tf.int64, shape=(None), name='input_y') - - imgnet_model = ImagenetModel(50, 'channels_last') - pred = imgnet_model(x, False) - pred = tf.argmax(pred, 1) - # correct_pred = tf.equal(numericPred, y) - # accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - - with open('./SampleImages/n02109961_36_enc.pkl', 'rb') as ff: - images = pickle.load(ff) - - numImages = len(images) - print("lenimages = ", numImages) - feed_dict = {x: images} - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - if savePreTrainedWeightsInt or savePreTrainedWeightsFloat or runPrediction or saveImgAndWtData: - modelPath = './PreTrainedModel/resnet_v2_fp32_savedmodel_NHWC/1538687283/variables/variables' - saver = tf.train.Saver() - saver.restore(sess, modelPath) - if savePreTrainedWeightsInt or savePreTrainedWeightsFloat or saveImgAndWtData: - DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess, feed_dict) - - predictions = None - - if runPrediction: - print("*************** Starting Prediction****************") - start_time = time.time() - predictions = sess.run([pred], feed_dict=feed_dict) - end_time = time.time() - print("*************** Done Prediction****************") - duration = end_time - start_time - print("Time taken in inference : ", duration) - with open('tf_pred.float','w+') as f: - f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) - with open('tf_pred.time','w') as f: - f.write(str(round(duration, 2))) - - trainVarsName = [] - for node in optimized_graph_def.node: - if node.op=="VariableV2": - trainVarsName.append(node.name) - trainVars = list(map(lambda x : tf.get_default_graph().get_operation_by_name(x).outputs[0] , trainVarsName)) - if savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeights(sess, trainVars, "model_weights_scale_{}.inp".format(scalingFac), scalingFac, 'w') - if savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, trainVars, 'model_weights_float.inp', 'w') - if saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, images[0], trainVars, "model_input_scale_{}.inp".format(scalingFac), - "model_weights_scale_{}.inp".format(scalingFac), scalingFac) - return predictions + +def infer( + savePreTrainedWeightsInt, + savePreTrainedWeightsFloat, + scalingFac, + runPrediction, + saveImgAndWtData, +): + x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name="input_x") + # y = tf.placeholder(tf.int64, shape=(None), name='input_y') + + imgnet_model = ImagenetModel(50, "channels_last") + pred = imgnet_model(x, False) + pred = tf.argmax(pred, 1) + # correct_pred = tf.equal(numericPred, y) + # accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + with open("./SampleImages/n02109961_36_enc.pkl", "rb") as ff: + images = pickle.load(ff) + + numImages = len(images) + print("lenimages = ", numImages) + feed_dict = {x: images} + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) + + if ( + savePreTrainedWeightsInt + or savePreTrainedWeightsFloat + or runPrediction + or saveImgAndWtData + ): + modelPath = "./PreTrainedModel/resnet_v2_fp32_savedmodel_NHWC/1538687283/variables/variables" + saver = tf.train.Saver() + saver.restore(sess, modelPath) + if ( + savePreTrainedWeightsInt + or savePreTrainedWeightsFloat + or saveImgAndWtData + ): + DumpTFMtData.updateWeightsForBN(optimized_graph_def, sess, feed_dict) + + predictions = None + + if runPrediction: + print("*************** Starting Prediction****************") + start_time = time.time() + predictions = sess.run([pred], feed_dict=feed_dict) + end_time = time.time() + print("*************** Done Prediction****************") + duration = end_time - start_time + print("Time taken in inference : ", duration) + with open("tf_pred.float", "w+") as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) + with open("tf_pred.time", "w") as f: + f.write(str(round(duration, 2))) + + trainVarsName = [] + for node in optimized_graph_def.node: + if node.op == "VariableV2": + trainVarsName.append(node.name) + trainVars = list( + map( + lambda x: tf.get_default_graph().get_operation_by_name(x).outputs[0], + trainVarsName, + ) + ) + if savePreTrainedWeightsInt: + DumpTFMtData.dumpTrainedWeights( + sess, + trainVars, + "model_weights_scale_{}.inp".format(scalingFac), + scalingFac, + "w", + ) + if savePreTrainedWeightsFloat: + DumpTFMtData.dumpTrainedWeightsFloat( + sess, trainVars, "model_weights_float.inp", "w" + ) + if saveImgAndWtData: + DumpTFMtData.dumpImgAndWeightsDataSeparate( + sess, + images[0], + trainVars, + "model_input_scale_{}.inp".format(scalingFac), + "model_weights_scale_{}.inp".format(scalingFac), + scalingFac, + ) + return predictions + def parseArgs(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() + + parser.add_argument( + "--savePreTrainedWeightsInt", + type=bool, + default=False, + help="savePreTrainedWeightsInt", + ) + parser.add_argument( + "--savePreTrainedWeightsFloat", + type=bool, + default=False, + help="savePreTrainedWeightsFloat", + ) + parser.add_argument("--scalingFac", type=int, default=15, help="scalingFac") + parser.add_argument( + "--runPrediction", type=bool, default=False, help="runPrediction" + ) + parser.add_argument( + "--saveImgAndWtData", type=bool, default=False, help="saveImgAndWtData" + ) - parser.add_argument("--savePreTrainedWeightsInt", type=bool, default=False, help="savePreTrainedWeightsInt") - parser.add_argument("--savePreTrainedWeightsFloat", type=bool, default=False, help="savePreTrainedWeightsFloat") - parser.add_argument("--scalingFac", type=int, default=15, help="scalingFac") - parser.add_argument("--runPrediction", type=bool, default=False, help="runPrediction") - parser.add_argument("--saveImgAndWtData", type=bool, default=False, help="saveImgAndWtData") + args = parser.parse_args() + return args - args = parser.parse_args() - return args def main(): - pred = None - args = parseArgs() - pred = infer(args.savePreTrainedWeightsInt, - args.savePreTrainedWeightsFloat, - args.scalingFac, - args.runPrediction, - args.saveImgAndWtData) - print("Prediction = ", pred) - return pred - -if __name__=='__main__': - pred = main() + pred = None + args = parseArgs() + pred = infer( + args.savePreTrainedWeightsInt, + args.savePreTrainedWeightsFloat, + args.scalingFac, + args.runPrediction, + args.saveImgAndWtData, + ) + print("Prediction = ", pred) + return pred + + +if __name__ == "__main__": + pred = main() diff --git a/Athos/Networks/ResNet/Resnet_Model.py b/Athos/Networks/ResNet/Resnet_Model.py index 1eff8e90..0c375b45 100644 --- a/Athos/Networks/ResNet/Resnet_Model.py +++ b/Athos/Networks/ResNet/Resnet_Model.py @@ -45,500 +45,598 @@ # Convenience functions for building the ResNet model. ################################################################################ def batch_norm(inputs, training, data_format): - """Performs a batch normalization using a standard set of parameters.""" - # We set fused=True for a significant performance boost. See - # https://www.tensorflow.org/performance/performance_guide#common_fused_ops - return tf.layers.batch_normalization( - inputs=inputs, axis=1 if data_format == 'channels_first' else 3, - momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True, - scale=True, training=training, fused=True) - # return inputs + """Performs a batch normalization using a standard set of parameters.""" + # We set fused=True for a significant performance boost. See + # https://www.tensorflow.org/performance/performance_guide#common_fused_ops + return tf.layers.batch_normalization( + inputs=inputs, + axis=1 if data_format == "channels_first" else 3, + momentum=_BATCH_NORM_DECAY, + epsilon=_BATCH_NORM_EPSILON, + center=True, + scale=True, + training=training, + fused=True, + ) + # return inputs def fixed_padding(inputs, kernel_size, data_format): - """Pads the input along the spatial dimensions independently of input size. - - Args: - inputs: A tensor of size [batch, channels, height_in, width_in] or - [batch, height_in, width_in, channels] depending on data_format. - kernel_size: The kernel to be used in the conv2d or max_pool2d operation. - Should be a positive integer. - data_format: The input format ('channels_last' or 'channels_first'). - - Returns: - A tensor with the same format as the input with the data either intact - (if kernel_size == 1) or padded (if kernel_size > 1). - """ - pad_total = kernel_size - 1 - pad_beg = pad_total // 2 - pad_end = pad_total - pad_beg - - if data_format == 'channels_first': - padded_inputs = tf.pad(tensor=inputs, - paddings=[[0, 0], [0, 0], [pad_beg, pad_end], - [pad_beg, pad_end]]) - else: - padded_inputs = tf.pad(tensor=inputs, - paddings=[[0, 0], [pad_beg, pad_end], - [pad_beg, pad_end], [0, 0]]) - return padded_inputs + """Pads the input along the spatial dimensions independently of input size. + Args: + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + kernel_size: The kernel to be used in the conv2d or max_pool2d operation. + Should be a positive integer. + data_format: The input format ('channels_last' or 'channels_first'). -def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format): - """Strided 2-D convolution with explicit padding.""" - # The padding is consistent and is based only on `kernel_size`, not on the - # dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). - if strides > 1: - inputs = fixed_padding(inputs, kernel_size, data_format) + Returns: + A tensor with the same format as the input with the data either intact + (if kernel_size == 1) or padded (if kernel_size > 1). + """ + pad_total = kernel_size - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + + if data_format == "channels_first": + padded_inputs = tf.pad( + tensor=inputs, + paddings=[[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]], + ) + else: + padded_inputs = tf.pad( + tensor=inputs, + paddings=[[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]], + ) + return padded_inputs - return tf.layers.conv2d( - inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides, - padding=('SAME' if strides == 1 else 'VALID'), use_bias=False, - kernel_initializer=tf.constant_initializer(value=0.001), - data_format=data_format) + +def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format): + """Strided 2-D convolution with explicit padding.""" + # The padding is consistent and is based only on `kernel_size`, not on the + # dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). + if strides > 1: + inputs = fixed_padding(inputs, kernel_size, data_format) + + return tf.layers.conv2d( + inputs=inputs, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=("SAME" if strides == 1 else "VALID"), + use_bias=False, + kernel_initializer=tf.constant_initializer(value=0.001), + data_format=data_format, + ) ################################################################################ # ResNet block definitions. ################################################################################ -def _building_block_v1(inputs, filters, training, projection_shortcut, strides, - data_format): - """A single block for ResNet v1, without a bottleneck. - - Convolution then batch normalization then ReLU as described by: - Deep Residual Learning for Image Recognition - https://arxiv.org/pdf/1512.03385.pdf - by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. - - Args: - inputs: A tensor of size [batch, channels, height_in, width_in] or - [batch, height_in, width_in, channels] depending on data_format. - filters: The number of filters for the convolutions. - training: A Boolean for whether the model is in training or inference - mode. Needed for batch normalization. - projection_shortcut: The function to use for projection shortcuts - (typically a 1x1 convolution when downsampling the input). - strides: The block's stride. If greater than 1, this block will ultimately - downsample the input. - data_format: The input format ('channels_last' or 'channels_first'). - - Returns: - The output tensor of the block; shape should match inputs. - """ - shortcut = inputs - - if projection_shortcut is not None: - shortcut = projection_shortcut(inputs) - shortcut = batch_norm(inputs=shortcut, training=training, - data_format=data_format) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=3, strides=strides, - data_format=data_format) - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=3, strides=1, - data_format=data_format) - inputs = batch_norm(inputs, training, data_format) - inputs += shortcut - inputs = tf.nn.relu(inputs) - - return inputs - - -def _building_block_v2(inputs, filters, training, projection_shortcut, strides, - data_format): - """A single block for ResNet v2, without a bottleneck. - - Batch normalization then ReLu then convolution as described by: - Identity Mappings in Deep Residual Networks - https://arxiv.org/pdf/1603.05027.pdf - by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. - - Args: - inputs: A tensor of size [batch, channels, height_in, width_in] or - [batch, height_in, width_in, channels] depending on data_format. - filters: The number of filters for the convolutions. - training: A Boolean for whether the model is in training or inference - mode. Needed for batch normalization. - projection_shortcut: The function to use for projection shortcuts - (typically a 1x1 convolution when downsampling the input). - strides: The block's stride. If greater than 1, this block will ultimately - downsample the input. - data_format: The input format ('channels_last' or 'channels_first'). - - Returns: - The output tensor of the block; shape should match inputs. - """ - shortcut = inputs - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - - # The projection shortcut should come after the first batch norm and ReLU - # since it performs a 1x1 convolution. - if projection_shortcut is not None: - shortcut = projection_shortcut(inputs) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=3, strides=strides, - data_format=data_format) - - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=3, strides=1, - data_format=data_format) - - return inputs + shortcut - - -def _bottleneck_block_v1(inputs, filters, training, projection_shortcut, - strides, data_format): - """A single block for ResNet v1, with a bottleneck. - - Similar to _building_block_v1(), except using the "bottleneck" blocks - described in: - Convolution then batch normalization then ReLU as described by: - Deep Residual Learning for Image Recognition - https://arxiv.org/pdf/1512.03385.pdf - by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. +def _building_block_v1( + inputs, filters, training, projection_shortcut, strides, data_format +): + """A single block for ResNet v1, without a bottleneck. - Args: - inputs: A tensor of size [batch, channels, height_in, width_in] or - [batch, height_in, width_in, channels] depending on data_format. - filters: The number of filters for the convolutions. - training: A Boolean for whether the model is in training or inference - mode. Needed for batch normalization. - projection_shortcut: The function to use for projection shortcuts - (typically a 1x1 convolution when downsampling the input). - strides: The block's stride. If greater than 1, this block will ultimately - downsample the input. - data_format: The input format ('channels_last' or 'channels_first'). - - Returns: - The output tensor of the block; shape should match inputs. - """ - shortcut = inputs - - if projection_shortcut is not None: - shortcut = projection_shortcut(inputs) - shortcut = batch_norm(inputs=shortcut, training=training, - data_format=data_format) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=1, strides=1, - data_format=data_format) - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=3, strides=strides, - data_format=data_format) - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, - data_format=data_format) - inputs = batch_norm(inputs, training, data_format) - inputs += shortcut - inputs = tf.nn.relu(inputs) - - return inputs - - -def _bottleneck_block_v2(inputs, filters, training, projection_shortcut, - strides, data_format): - """A single block for ResNet v2, with a bottleneck. - - Similar to _building_block_v2(), except using the "bottleneck" blocks - described in: Convolution then batch normalization then ReLU as described by: Deep Residual Learning for Image Recognition https://arxiv.org/pdf/1512.03385.pdf by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. - Adapted to the ordering conventions of: + Args: + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + filters: The number of filters for the convolutions. + training: A Boolean for whether the model is in training or inference + mode. Needed for batch normalization. + projection_shortcut: The function to use for projection shortcuts + (typically a 1x1 convolution when downsampling the input). + strides: The block's stride. If greater than 1, this block will ultimately + downsample the input. + data_format: The input format ('channels_last' or 'channels_first'). + + Returns: + The output tensor of the block; shape should match inputs. + """ + shortcut = inputs + + if projection_shortcut is not None: + shortcut = projection_shortcut(inputs) + shortcut = batch_norm( + inputs=shortcut, training=training, data_format=data_format + ) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=strides, + data_format=data_format, + ) + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=1, + data_format=data_format, + ) + inputs = batch_norm(inputs, training, data_format) + inputs += shortcut + inputs = tf.nn.relu(inputs) + + return inputs + + +def _building_block_v2( + inputs, filters, training, projection_shortcut, strides, data_format +): + """A single block for ResNet v2, without a bottleneck. + Batch normalization then ReLu then convolution as described by: Identity Mappings in Deep Residual Networks https://arxiv.org/pdf/1603.05027.pdf by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. - Args: - inputs: A tensor of size [batch, channels, height_in, width_in] or - [batch, height_in, width_in, channels] depending on data_format. - filters: The number of filters for the convolutions. - training: A Boolean for whether the model is in training or inference - mode. Needed for batch normalization. - projection_shortcut: The function to use for projection shortcuts - (typically a 1x1 convolution when downsampling the input). - strides: The block's stride. If greater than 1, this block will ultimately - downsample the input. - data_format: The input format ('channels_last' or 'channels_first'). - - Returns: - The output tensor of the block; shape should match inputs. - """ - shortcut = inputs - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - - # The projection shortcut should come after the first batch norm and ReLU - # since it performs a 1x1 convolution. - if projection_shortcut is not None: - shortcut = projection_shortcut(inputs) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=1, strides=1, - data_format=data_format) - - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=3, strides=strides, - data_format=data_format) - - inputs = batch_norm(inputs, training, data_format) - inputs = tf.nn.relu(inputs) - inputs = conv2d_fixed_padding( - inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, - data_format=data_format) - - return inputs + shortcut - - -def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides, - training, name, data_format): - """Creates one layer of blocks for the ResNet model. - - Args: - inputs: A tensor of size [batch, channels, height_in, width_in] or - [batch, height_in, width_in, channels] depending on data_format. - filters: The number of filters for the first convolution of the layer. - bottleneck: Is the block created a bottleneck block. - block_fn: The block to use within the model, either `building_block` or - `bottleneck_block`. - blocks: The number of blocks contained in the layer. - strides: The stride to use for the first convolution of the layer. If - greater than 1, this layer will ultimately downsample the input. - training: Either True or False, whether we are currently training the - model. Needed for batch norm. - name: A string name for the tensor output of the block layer. - data_format: The input format ('channels_last' or 'channels_first'). - - Returns: - The output tensor of the block layer. - """ - - # Bottleneck blocks end with 4x the number of filters as they start with - filters_out = filters * 4 if bottleneck else filters - - def projection_shortcut(inputs): - return conv2d_fixed_padding( - inputs=inputs, filters=filters_out, kernel_size=1, strides=strides, - data_format=data_format) - - # Only the first block per block_layer uses projection_shortcut and strides - inputs = block_fn(inputs, filters, training, projection_shortcut, strides, - data_format) - - for _ in range(1, blocks): - inputs = block_fn(inputs, filters, training, None, 1, data_format) - - return tf.identity(inputs, name) - - -class Model(object): - """Base class for building the Resnet Model.""" - - def __init__(self, resnet_size, bottleneck, num_classes, num_filters, - kernel_size, - conv_stride, first_pool_size, first_pool_stride, - block_sizes, block_strides, - resnet_version=DEFAULT_VERSION, data_format=None, - dtype=DEFAULT_DTYPE): - """Creates a model for classifying an image. - Args: - resnet_size: A single integer for the size of the ResNet model. - bottleneck: Use regular blocks or bottleneck blocks. - num_classes: The number of classes used as labels. - num_filters: The number of filters to use for the first block layer - of the model. This number is then doubled for each subsequent block - layer. - kernel_size: The kernel size to use for convolution. - conv_stride: stride size for the initial convolutional layer - first_pool_size: Pool size to be used for the first pooling layer. - If none, the first pooling layer is skipped. - first_pool_stride: stride size for the first pooling layer. Not used - if first_pool_size is None. - block_sizes: A list containing n values, where n is the number of sets of - block layers desired. Each value should be the number of blocks in the - i-th set. - block_strides: List of integers representing the desired stride size for - each of the sets of block layers. Should be same length as block_sizes. - resnet_version: Integer representing which version of the ResNet network - to use. See README for details. Valid values: [1, 2] - data_format: Input format ('channels_last', 'channels_first', or None). - If set to None, the format is dependent on whether a GPU is available. - dtype: The TensorFlow dtype to use for calculations. If not specified - tf.float32 is used. - - Raises: - ValueError: if invalid version is selected. + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + filters: The number of filters for the convolutions. + training: A Boolean for whether the model is in training or inference + mode. Needed for batch normalization. + projection_shortcut: The function to use for projection shortcuts + (typically a 1x1 convolution when downsampling the input). + strides: The block's stride. If greater than 1, this block will ultimately + downsample the input. + data_format: The input format ('channels_last' or 'channels_first'). + + Returns: + The output tensor of the block; shape should match inputs. """ - self.resnet_size = resnet_size - - if not data_format: - data_format = ( - 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') - - self.resnet_version = resnet_version - if resnet_version not in (1, 2): - raise ValueError( - 'Resnet version should be 1 or 2. See README for citations.') - - self.bottleneck = bottleneck - if bottleneck: - if resnet_version == 1: - self.block_fn = _bottleneck_block_v1 - else: - self.block_fn = _bottleneck_block_v2 - else: - if resnet_version == 1: - self.block_fn = _building_block_v1 - else: - self.block_fn = _building_block_v2 - - if dtype not in ALLOWED_TYPES: - raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES)) - - self.data_format = data_format - self.num_classes = num_classes - self.num_filters = num_filters - self.kernel_size = kernel_size - self.conv_stride = conv_stride - self.first_pool_size = first_pool_size - self.first_pool_stride = first_pool_stride - self.block_sizes = block_sizes - self.block_strides = block_strides - self.dtype = dtype - self.pre_activation = resnet_version == 2 - - def _custom_dtype_getter(self, getter, name, shape=None, dtype=DEFAULT_DTYPE, - *args, **kwargs): - """Creates variables in fp32, then casts to fp16 if necessary. - - This function is a custom getter. A custom getter is a function with the - same signature as tf.get_variable, except it has an additional getter - parameter. Custom getters can be passed as the `custom_getter` parameter of - tf.variable_scope. Then, tf.get_variable will call the custom getter, - instead of directly getting a variable itself. This can be used to change - the types of variables that are retrieved with tf.get_variable. - The `getter` parameter is the underlying variable getter, that would have - been called if no custom getter was used. Custom getters typically get a - variable with `getter`, then modify it in some way. - - This custom getter will create an fp32 variable. If a low precision - (e.g. float16) variable was requested it will then cast the variable to the - requested dtype. The reason we do not directly create variables in low - precision dtypes is that applying small gradients to such variables may - cause the variable not to change. + shortcut = inputs + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + + # The projection shortcut should come after the first batch norm and ReLU + # since it performs a 1x1 convolution. + if projection_shortcut is not None: + shortcut = projection_shortcut(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=strides, + data_format=data_format, + ) + + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=1, + data_format=data_format, + ) + + return inputs + shortcut + + +def _bottleneck_block_v1( + inputs, filters, training, projection_shortcut, strides, data_format +): + """A single block for ResNet v1, with a bottleneck. + + Similar to _building_block_v1(), except using the "bottleneck" blocks + described in: + Convolution then batch normalization then ReLU as described by: + Deep Residual Learning for Image Recognition + https://arxiv.org/pdf/1512.03385.pdf + by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. Args: - getter: The underlying variable getter, that has the same signature as - tf.get_variable and returns a variable. - name: The name of the variable to get. - shape: The shape of the variable to get. - dtype: The dtype of the variable to get. Note that if this is a low - precision dtype, the variable will be created as a tf.float32 variable, - then cast to the appropriate dtype - *args: Additional arguments to pass unmodified to getter. - **kwargs: Additional keyword arguments to pass unmodified to getter. + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + filters: The number of filters for the convolutions. + training: A Boolean for whether the model is in training or inference + mode. Needed for batch normalization. + projection_shortcut: The function to use for projection shortcuts + (typically a 1x1 convolution when downsampling the input). + strides: The block's stride. If greater than 1, this block will ultimately + downsample the input. + data_format: The input format ('channels_last' or 'channels_first'). Returns: - A variable which is cast to fp16 if necessary. + The output tensor of the block; shape should match inputs. """ + shortcut = inputs + + if projection_shortcut is not None: + shortcut = projection_shortcut(inputs) + shortcut = batch_norm( + inputs=shortcut, training=training, data_format=data_format + ) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=1, + strides=1, + data_format=data_format, + ) + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=strides, + data_format=data_format, + ) + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=4 * filters, + kernel_size=1, + strides=1, + data_format=data_format, + ) + inputs = batch_norm(inputs, training, data_format) + inputs += shortcut + inputs = tf.nn.relu(inputs) + + return inputs + + +def _bottleneck_block_v2( + inputs, filters, training, projection_shortcut, strides, data_format +): + """A single block for ResNet v2, with a bottleneck. + + Similar to _building_block_v2(), except using the "bottleneck" blocks + described in: + Convolution then batch normalization then ReLU as described by: + Deep Residual Learning for Image Recognition + https://arxiv.org/pdf/1512.03385.pdf + by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. + + Adapted to the ordering conventions of: + Batch normalization then ReLu then convolution as described by: + Identity Mappings in Deep Residual Networks + https://arxiv.org/pdf/1603.05027.pdf + by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. - if dtype in CASTABLE_TYPES: - var = getter(name, shape, tf.float32, *args, **kwargs) - return tf.cast(var, dtype=dtype, name=name + '_cast') - else: - return getter(name, shape, dtype, *args, **kwargs) + Args: + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + filters: The number of filters for the convolutions. + training: A Boolean for whether the model is in training or inference + mode. Needed for batch normalization. + projection_shortcut: The function to use for projection shortcuts + (typically a 1x1 convolution when downsampling the input). + strides: The block's stride. If greater than 1, this block will ultimately + downsample the input. + data_format: The input format ('channels_last' or 'channels_first'). - def _model_variable_scope(self): - """Returns a variable scope that the model should be created under. + Returns: + The output tensor of the block; shape should match inputs. + """ + shortcut = inputs + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + + # The projection shortcut should come after the first batch norm and ReLU + # since it performs a 1x1 convolution. + if projection_shortcut is not None: + shortcut = projection_shortcut(inputs) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=1, + strides=1, + data_format=data_format, + ) + + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=filters, + kernel_size=3, + strides=strides, + data_format=data_format, + ) + + inputs = batch_norm(inputs, training, data_format) + inputs = tf.nn.relu(inputs) + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=4 * filters, + kernel_size=1, + strides=1, + data_format=data_format, + ) + + return inputs + shortcut + + +def block_layer( + inputs, filters, bottleneck, block_fn, blocks, strides, training, name, data_format +): + """Creates one layer of blocks for the ResNet model. - If self.dtype is a castable type, model variable will be created in fp32 - then cast to self.dtype before being used. + Args: + inputs: A tensor of size [batch, channels, height_in, width_in] or + [batch, height_in, width_in, channels] depending on data_format. + filters: The number of filters for the first convolution of the layer. + bottleneck: Is the block created a bottleneck block. + block_fn: The block to use within the model, either `building_block` or + `bottleneck_block`. + blocks: The number of blocks contained in the layer. + strides: The stride to use for the first convolution of the layer. If + greater than 1, this layer will ultimately downsample the input. + training: Either True or False, whether we are currently training the + model. Needed for batch norm. + name: A string name for the tensor output of the block layer. + data_format: The input format ('channels_last' or 'channels_first'). Returns: - A variable scope for the model. + The output tensor of the block layer. """ - return tf.variable_scope('resnet_model', - custom_getter=self._custom_dtype_getter) + # Bottleneck blocks end with 4x the number of filters as they start with + filters_out = filters * 4 if bottleneck else filters - def __call__(self, inputs, training): - """Add operations to classify a batch of input images. + def projection_shortcut(inputs): + return conv2d_fixed_padding( + inputs=inputs, + filters=filters_out, + kernel_size=1, + strides=strides, + data_format=data_format, + ) - Args: - inputs: A Tensor representing a batch of input images. - training: A boolean. Set to True to add operations required only when - training the classifier. + # Only the first block per block_layer uses projection_shortcut and strides + inputs = block_fn( + inputs, filters, training, projection_shortcut, strides, data_format + ) - Returns: - A logits Tensor with shape [, self.num_classes]. - """ + for _ in range(1, blocks): + inputs = block_fn(inputs, filters, training, None, 1, data_format) + + return tf.identity(inputs, name) - with self._model_variable_scope(): - if self.data_format == 'channels_first': - # Convert the inputs from channels_last (NHWC) to channels_first (NCHW). - # This provides a large performance boost on GPU. See - # https://www.tensorflow.org/performance/performance_guide#data_formats - inputs = tf.transpose(a=inputs, perm=[0, 3, 1, 2]) - - inputs = conv2d_fixed_padding( - inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size, - strides=self.conv_stride, data_format=self.data_format) - inputs = tf.identity(inputs, 'initial_conv') - - # We do not include batch normalization or activation functions in V2 - # for the initial conv1 because the first ResNet unit will perform these - # for both the shortcut and non-shortcut paths as part of the first - # block's projection. Cf. Appendix of [2]. - if self.resnet_version == 1: - inputs = batch_norm(inputs, training, self.data_format) - inputs = tf.nn.relu(inputs) - - if self.first_pool_size: - inputs = tf.layers.max_pooling2d( - inputs=inputs, pool_size=self.first_pool_size, - strides=self.first_pool_stride, padding='SAME', - data_format=self.data_format) - inputs = tf.identity(inputs, 'initial_max_pool') - - for i, num_blocks in enumerate(self.block_sizes): - num_filters = self.num_filters * (2**i) - inputs = block_layer( - inputs=inputs, filters=num_filters, bottleneck=self.bottleneck, - block_fn=self.block_fn, blocks=num_blocks, - strides=self.block_strides[i], training=training, - name='block_layer{}'.format(i + 1), data_format=self.data_format) - - # Only apply the BN and ReLU for model that does pre_activation in each - # building/bottleneck block, eg resnet V2. - if self.pre_activation: - inputs = batch_norm(inputs, training, self.data_format) - inputs = tf.nn.relu(inputs) - - axes = [2, 3] if self.data_format == 'channels_first' else [1, 2] - inputs = tf.nn.avg_pool(inputs, ksize=[1,7,7,1], strides=[1,1,1,1], padding='VALID') - inputs = tf.identity(inputs, 'final_reduce_mean') - inputs = tf.squeeze(inputs, axes) - inputs = tf.layers.dense(inputs=inputs, units=self.num_classes, kernel_initializer=tf.constant_initializer(value=0.001)) - - inputs = tf.identity(inputs, 'final_dense') - return inputs + +class Model(object): + """Base class for building the Resnet Model.""" + + def __init__( + self, + resnet_size, + bottleneck, + num_classes, + num_filters, + kernel_size, + conv_stride, + first_pool_size, + first_pool_stride, + block_sizes, + block_strides, + resnet_version=DEFAULT_VERSION, + data_format=None, + dtype=DEFAULT_DTYPE, + ): + """Creates a model for classifying an image. + + Args: + resnet_size: A single integer for the size of the ResNet model. + bottleneck: Use regular blocks or bottleneck blocks. + num_classes: The number of classes used as labels. + num_filters: The number of filters to use for the first block layer + of the model. This number is then doubled for each subsequent block + layer. + kernel_size: The kernel size to use for convolution. + conv_stride: stride size for the initial convolutional layer + first_pool_size: Pool size to be used for the first pooling layer. + If none, the first pooling layer is skipped. + first_pool_stride: stride size for the first pooling layer. Not used + if first_pool_size is None. + block_sizes: A list containing n values, where n is the number of sets of + block layers desired. Each value should be the number of blocks in the + i-th set. + block_strides: List of integers representing the desired stride size for + each of the sets of block layers. Should be same length as block_sizes. + resnet_version: Integer representing which version of the ResNet network + to use. See README for details. Valid values: [1, 2] + data_format: Input format ('channels_last', 'channels_first', or None). + If set to None, the format is dependent on whether a GPU is available. + dtype: The TensorFlow dtype to use for calculations. If not specified + tf.float32 is used. + + Raises: + ValueError: if invalid version is selected. + """ + self.resnet_size = resnet_size + + if not data_format: + data_format = ( + "channels_first" if tf.test.is_built_with_cuda() else "channels_last" + ) + + self.resnet_version = resnet_version + if resnet_version not in (1, 2): + raise ValueError( + "Resnet version should be 1 or 2. See README for citations." + ) + + self.bottleneck = bottleneck + if bottleneck: + if resnet_version == 1: + self.block_fn = _bottleneck_block_v1 + else: + self.block_fn = _bottleneck_block_v2 + else: + if resnet_version == 1: + self.block_fn = _building_block_v1 + else: + self.block_fn = _building_block_v2 + + if dtype not in ALLOWED_TYPES: + raise ValueError("dtype must be one of: {}".format(ALLOWED_TYPES)) + + self.data_format = data_format + self.num_classes = num_classes + self.num_filters = num_filters + self.kernel_size = kernel_size + self.conv_stride = conv_stride + self.first_pool_size = first_pool_size + self.first_pool_stride = first_pool_stride + self.block_sizes = block_sizes + self.block_strides = block_strides + self.dtype = dtype + self.pre_activation = resnet_version == 2 + + def _custom_dtype_getter( + self, getter, name, shape=None, dtype=DEFAULT_DTYPE, *args, **kwargs + ): + """Creates variables in fp32, then casts to fp16 if necessary. + + This function is a custom getter. A custom getter is a function with the + same signature as tf.get_variable, except it has an additional getter + parameter. Custom getters can be passed as the `custom_getter` parameter of + tf.variable_scope. Then, tf.get_variable will call the custom getter, + instead of directly getting a variable itself. This can be used to change + the types of variables that are retrieved with tf.get_variable. + The `getter` parameter is the underlying variable getter, that would have + been called if no custom getter was used. Custom getters typically get a + variable with `getter`, then modify it in some way. + + This custom getter will create an fp32 variable. If a low precision + (e.g. float16) variable was requested it will then cast the variable to the + requested dtype. The reason we do not directly create variables in low + precision dtypes is that applying small gradients to such variables may + cause the variable not to change. + + Args: + getter: The underlying variable getter, that has the same signature as + tf.get_variable and returns a variable. + name: The name of the variable to get. + shape: The shape of the variable to get. + dtype: The dtype of the variable to get. Note that if this is a low + precision dtype, the variable will be created as a tf.float32 variable, + then cast to the appropriate dtype + *args: Additional arguments to pass unmodified to getter. + **kwargs: Additional keyword arguments to pass unmodified to getter. + + Returns: + A variable which is cast to fp16 if necessary. + """ + + if dtype in CASTABLE_TYPES: + var = getter(name, shape, tf.float32, *args, **kwargs) + return tf.cast(var, dtype=dtype, name=name + "_cast") + else: + return getter(name, shape, dtype, *args, **kwargs) + + def _model_variable_scope(self): + """Returns a variable scope that the model should be created under. + + If self.dtype is a castable type, model variable will be created in fp32 + then cast to self.dtype before being used. + + Returns: + A variable scope for the model. + """ + + return tf.variable_scope( + "resnet_model", custom_getter=self._custom_dtype_getter + ) + + def __call__(self, inputs, training): + """Add operations to classify a batch of input images. + + Args: + inputs: A Tensor representing a batch of input images. + training: A boolean. Set to True to add operations required only when + training the classifier. + + Returns: + A logits Tensor with shape [, self.num_classes]. + """ + + with self._model_variable_scope(): + if self.data_format == "channels_first": + # Convert the inputs from channels_last (NHWC) to channels_first (NCHW). + # This provides a large performance boost on GPU. See + # https://www.tensorflow.org/performance/performance_guide#data_formats + inputs = tf.transpose(a=inputs, perm=[0, 3, 1, 2]) + + inputs = conv2d_fixed_padding( + inputs=inputs, + filters=self.num_filters, + kernel_size=self.kernel_size, + strides=self.conv_stride, + data_format=self.data_format, + ) + inputs = tf.identity(inputs, "initial_conv") + + # We do not include batch normalization or activation functions in V2 + # for the initial conv1 because the first ResNet unit will perform these + # for both the shortcut and non-shortcut paths as part of the first + # block's projection. Cf. Appendix of [2]. + if self.resnet_version == 1: + inputs = batch_norm(inputs, training, self.data_format) + inputs = tf.nn.relu(inputs) + + if self.first_pool_size: + inputs = tf.layers.max_pooling2d( + inputs=inputs, + pool_size=self.first_pool_size, + strides=self.first_pool_stride, + padding="SAME", + data_format=self.data_format, + ) + inputs = tf.identity(inputs, "initial_max_pool") + + for i, num_blocks in enumerate(self.block_sizes): + num_filters = self.num_filters * (2 ** i) + inputs = block_layer( + inputs=inputs, + filters=num_filters, + bottleneck=self.bottleneck, + block_fn=self.block_fn, + blocks=num_blocks, + strides=self.block_strides[i], + training=training, + name="block_layer{}".format(i + 1), + data_format=self.data_format, + ) + + # Only apply the BN and ReLU for model that does pre_activation in each + # building/bottleneck block, eg resnet V2. + if self.pre_activation: + inputs = batch_norm(inputs, training, self.data_format) + inputs = tf.nn.relu(inputs) + + axes = [2, 3] if self.data_format == "channels_first" else [1, 2] + inputs = tf.nn.avg_pool( + inputs, ksize=[1, 7, 7, 1], strides=[1, 1, 1, 1], padding="VALID" + ) + inputs = tf.identity(inputs, "final_reduce_mean") + inputs = tf.squeeze(inputs, axes) + inputs = tf.layers.dense( + inputs=inputs, + units=self.num_classes, + kernel_initializer=tf.constant_initializer(value=0.001), + ) + + inputs = tf.identity(inputs, "final_dense") + return inputs diff --git a/Athos/Networks/SecureNNBenchmarks/NetworkA.py b/Athos/Networks/SecureNNBenchmarks/NetworkA.py index aa912249..aa2562a5 100644 --- a/Athos/Networks/SecureNNBenchmarks/NetworkA.py +++ b/Athos/Networks/SecureNNBenchmarks/NetworkA.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # # This is the network A from SecureNN written in Tensorflow. @@ -31,65 +31,69 @@ import os, sys import numpy as np import tensorflow as tf -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData useRELUActivation = False + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - # initial = tf.truncated_normal(shape, stddev=0.1) - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + # initial = tf.truncated_normal(shape, stddev=0.1) + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) + def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) + x = tf.placeholder(tf.float32, [None, 784]) -#fc1 -with tf.name_scope('fc1'): - w_fc1 = weight_variable([784,128]) - b_fc1 = bias_variable([128]) - outp1 = tf.matmul(x, w_fc1) + b_fc1 +# fc1 +with tf.name_scope("fc1"): + w_fc1 = weight_variable([784, 128]) + b_fc1 = bias_variable([128]) + outp1 = tf.matmul(x, w_fc1) + b_fc1 if useRELUActivation: - actv1 = tf.nn.relu(outp1) + actv1 = tf.nn.relu(outp1) else: - actv1 = tf.square(outp1) + actv1 = tf.square(outp1) -with tf.name_scope('fc2'): - w_fc2 = weight_variable([128,128]) - b_fc2 = bias_variable([128]) - outp2 = tf.matmul(actv1, w_fc2) + b_fc2 +with tf.name_scope("fc2"): + w_fc2 = weight_variable([128, 128]) + b_fc2 = bias_variable([128]) + outp2 = tf.matmul(actv1, w_fc2) + b_fc2 if useRELUActivation: - actv2 = tf.nn.relu(outp2) + actv2 = tf.nn.relu(outp2) else: - actv2 = tf.square(outp2) + actv2 = tf.square(outp2) -with tf.name_scope('fc3'): - w_fc3 = weight_variable([128,10]) - b_fc3 = bias_variable([10]) - outp3 = tf.matmul(actv2, w_fc3) + b_fc3 +with tf.name_scope("fc3"): + w_fc3 = weight_variable([128, 10]) + b_fc3 = bias_variable([10]) + outp3 = tf.matmul(actv2, w_fc3) + b_fc3 -finalOut = tf.argmax(outp3,1) +finalOut = tf.argmax(outp3, 1) with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - imgData = [[0.02]*784] - feed_dict = {x: imgData} - - pred = sess.run(finalOut, feed_dict=feed_dict) - print(pred) - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - - + sess.run(tf.global_variables_initializer()) + imgData = [[0.02] * 784] + feed_dict = {x: imgData} + + pred = sess.run(finalOut, feed_dict=feed_dict) + print(pred) + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) diff --git a/Athos/Networks/SecureNNBenchmarks/NetworkB.py b/Athos/Networks/SecureNNBenchmarks/NetworkB.py index 341888aa..9e4e1af6 100644 --- a/Athos/Networks/SecureNNBenchmarks/NetworkB.py +++ b/Athos/Networks/SecureNNBenchmarks/NetworkB.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # # This is the network B from SecureNN written in Tensorflow. @@ -30,58 +30,70 @@ import os, sys import numpy as np import tensorflow as tf -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - # initial = tf.truncated_normal(shape, stddev=0.1) - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + # initial = tf.truncated_normal(shape, stddev=0.1) + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) + def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) -x = tf.placeholder(tf.float32, [None,784]) -#conv1 -w_conv1 = weight_variable([5,5,1,16]) -conv1 = tf.nn.conv2d(tf.reshape(x, [-1,28,28,1]), w_conv1, strides=[1,1,1,1], padding='VALID') +x = tf.placeholder(tf.float32, [None, 784]) + +# conv1 +w_conv1 = weight_variable([5, 5, 1, 16]) +conv1 = tf.nn.conv2d( + tf.reshape(x, [-1, 28, 28, 1]), w_conv1, strides=[1, 1, 1, 1], padding="VALID" +) relu1 = tf.nn.relu(conv1) -maxpool1 = tf.nn.max_pool(relu1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID') +maxpool1 = tf.nn.max_pool( + relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID" +) -#conv2 -w_conv2 = weight_variable([5,5,16,16]) -conv2 = tf.nn.conv2d(maxpool1, w_conv2, strides=[1,1,1,1], padding='VALID') +# conv2 +w_conv2 = weight_variable([5, 5, 16, 16]) +conv2 = tf.nn.conv2d(maxpool1, w_conv2, strides=[1, 1, 1, 1], padding="VALID") relu2 = tf.nn.relu(conv2) -maxpool2 = tf.nn.max_pool(relu2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID') +maxpool2 = tf.nn.max_pool( + relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID" +) -#fc1 -w_fc1 = weight_variable([256,100]) +# fc1 +w_fc1 = weight_variable([256, 100]) b_fc1 = bias_variable([100]) -fc1 = tf.matmul(tf.reshape(maxpool2, [-1,256]), w_fc1) + b_fc1 +fc1 = tf.matmul(tf.reshape(maxpool2, [-1, 256]), w_fc1) + b_fc1 relu3 = tf.nn.relu(fc1) -#fc2 -w_fc2 = weight_variable([100,10]) +# fc2 +w_fc2 = weight_variable([100, 10]) b_fc2 = bias_variable([10]) fc2 = tf.matmul(relu3, w_fc2) + b_fc2 finalOut = tf.argmax(fc2, 1) with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - imgData = [[0.02]*784] - feed_dict = {x: imgData} - - pred = sess.run(finalOut, feed_dict=feed_dict) - print(pred) - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) + sess.run(tf.global_variables_initializer()) + imgData = [[0.02] * 784] + feed_dict = {x: imgData} + + pred = sess.run(finalOut, feed_dict=feed_dict) + print(pred) + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) diff --git a/Athos/Networks/SecureNNBenchmarks/NetworkC.py b/Athos/Networks/SecureNNBenchmarks/NetworkC.py index 43df4987..dd23f969 100644 --- a/Athos/Networks/SecureNNBenchmarks/NetworkC.py +++ b/Athos/Networks/SecureNNBenchmarks/NetworkC.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # # This is the network C from SecureNN written in Tensorflow. @@ -31,58 +31,70 @@ import os, sys import numpy as np import tensorflow as tf -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - # initial = tf.truncated_normal(shape, stddev=0.1) - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + # initial = tf.truncated_normal(shape, stddev=0.1) + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) + def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) -x = tf.placeholder(tf.float32, [None,784]) -#conv1 -w_conv1 = weight_variable([5,5,1,20]) -conv1 = tf.nn.conv2d(tf.reshape(x, [-1,28,28,1]), w_conv1, strides=[1,1,1,1], padding='VALID') +x = tf.placeholder(tf.float32, [None, 784]) + +# conv1 +w_conv1 = weight_variable([5, 5, 1, 20]) +conv1 = tf.nn.conv2d( + tf.reshape(x, [-1, 28, 28, 1]), w_conv1, strides=[1, 1, 1, 1], padding="VALID" +) relu1 = tf.nn.relu(conv1) -maxpool1 = tf.nn.max_pool(relu1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID') +maxpool1 = tf.nn.max_pool( + relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID" +) -#conv2 -w_conv2 = weight_variable([5,5,20,50]) -conv2 = tf.nn.conv2d(maxpool1, w_conv2, strides=[1,1,1,1], padding='VALID') +# conv2 +w_conv2 = weight_variable([5, 5, 20, 50]) +conv2 = tf.nn.conv2d(maxpool1, w_conv2, strides=[1, 1, 1, 1], padding="VALID") relu2 = tf.nn.relu(conv2) -maxpool2 = tf.nn.max_pool(relu2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID') +maxpool2 = tf.nn.max_pool( + relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID" +) -#fc1 -w_fc1 = weight_variable([800,500]) +# fc1 +w_fc1 = weight_variable([800, 500]) b_fc1 = bias_variable([500]) -fc1 = tf.matmul(tf.reshape(maxpool2, [-1,800]), w_fc1) + b_fc1 +fc1 = tf.matmul(tf.reshape(maxpool2, [-1, 800]), w_fc1) + b_fc1 relu3 = tf.nn.relu(fc1) -#fc2 -w_fc2 = weight_variable([500,10]) +# fc2 +w_fc2 = weight_variable([500, 10]) b_fc2 = bias_variable([10]) fc2 = tf.matmul(relu3, w_fc2) + b_fc2 finalOut = tf.argmax(fc2, 1) with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - imgData = [[0.02]*784] - feed_dict = {x: imgData} - - pred = sess.run(finalOut, feed_dict=feed_dict) - print(pred) - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) + sess.run(tf.global_variables_initializer()) + imgData = [[0.02] * 784] + feed_dict = {x: imgData} + + pred = sess.run(finalOut, feed_dict=feed_dict) + print(pred) + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) diff --git a/Athos/Networks/SecureNNBenchmarks/NetworkD.py b/Athos/Networks/SecureNNBenchmarks/NetworkD.py index b099a974..07e5d8aa 100644 --- a/Athos/Networks/SecureNNBenchmarks/NetworkD.py +++ b/Athos/Networks/SecureNNBenchmarks/NetworkD.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" # # This is the network D from SecureNN written in Tensorflow. @@ -30,51 +30,59 @@ import os, sys import numpy as np import tensorflow as tf -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData + def weight_variable(shape): - """weight_variable generates a weight variable of a given shape.""" - # initial = tf.truncated_normal(shape, stddev=0.1) - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """weight_variable generates a weight variable of a given shape.""" + # initial = tf.truncated_normal(shape, stddev=0.1) + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) + def bias_variable(shape): - """bias_variable generates a bias variable of a given shape.""" - initial = tf.constant(0.01, shape=shape) - return tf.Variable(initial) + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.01, shape=shape) + return tf.Variable(initial) -x = tf.placeholder(tf.float32, [None,784]) -#conv1 -w_conv1 = weight_variable([5,5,1,5]) -conv1 = tf.nn.conv2d(tf.reshape(x, [-1,28,28,1]), w_conv1, strides=[1,2,2,1], padding='SAME') +x = tf.placeholder(tf.float32, [None, 784]) + +# conv1 +w_conv1 = weight_variable([5, 5, 1, 5]) +conv1 = tf.nn.conv2d( + tf.reshape(x, [-1, 28, 28, 1]), w_conv1, strides=[1, 2, 2, 1], padding="SAME" +) relu1 = tf.nn.relu(conv1) -#fc1 -w_fc1 = weight_variable([980,100]) +# fc1 +w_fc1 = weight_variable([980, 100]) b_fc1 = bias_variable([100]) -fc1 = tf.matmul(tf.reshape(relu1, [-1,980]), w_fc1) + b_fc1 +fc1 = tf.matmul(tf.reshape(relu1, [-1, 980]), w_fc1) + b_fc1 relu3 = tf.nn.relu(fc1) -#fc2 -w_fc2 = weight_variable([100,10]) +# fc2 +w_fc2 = weight_variable([100, 10]) b_fc2 = bias_variable([10]) fc2 = tf.matmul(relu3, w_fc2) + b_fc2 finalOut = tf.argmax(fc2, 1) with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - imgData = [[0.02]*784] - feed_dict = {x: imgData} - - pred = sess.run(finalOut, feed_dict=feed_dict) - print(pred) - - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) + sess.run(tf.global_variables_initializer()) + imgData = [[0.02] * 784] + feed_dict = {x: imgData} + + pred = sess.run(finalOut, feed_dict=feed_dict) + print(pred) + + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) diff --git a/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py b/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py index 8e59916a..aacd8977 100644 --- a/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py +++ b/Athos/Networks/SqueezeNetCIFAR10/Squeezenet_model.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -25,7 +25,7 @@ were taken from https://github.com/kaizouman/tensorsandbox/tree/master/cifar10/models/squeeze ** -''' +""" from __future__ import absolute_import from __future__ import division @@ -42,589 +42,699 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from tensorflow.python.util import deprecation + deprecation._PRINT_DEPRECATION_WARNINGS = False try: - from tensorflow.python.util import module_wrapper as deprecation + from tensorflow.python.util import module_wrapper as deprecation except ImportError: - from tensorflow.python.util import deprecation_wrapper as deprecation + from tensorflow.python.util import deprecation_wrapper as deprecation deprecation._PER_MODULE_WARNING_LIMIT = 0 -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData from argparse import ArgumentParser + class SqueezeNet1Orig: - def __init__(self): - self.all_weights = [] - - def inference(self, images): - # conv1 - conv1 = self.conv_layer(images, - size=3, - filters=64, - stride=1, - decay=False, - name='conv1') - - # pool1 - pool1 = self.pool_layer(conv1, - size=3, - stride=2, - name='pool1') - - # fire2 - fire2 = self.fire_layer(pool1, 32, 64, 64, decay=False, name='fire2') - - # fire3 - fire3 = self.fire_layer(fire2, 32, 64, 64, decay=False, name='fire3') - - # pool2 - pool2 = self.pool_layer(fire3, - size=3, - stride=2, - name='pool2') - - # fire4 - fire4 = self.fire_layer(pool2, 32, 128, 128, decay=False, name='fire4') - - # fire5 - fire5 = self.fire_layer(fire4, 32, 128, 128, decay=False, name='fire5') - - # Final squeeze to get ten classes - conv2 = self.conv_layer(fire5, - size=1, - filters=10, - stride=1, - decay=False, - name='squeeze') - - # Average pooling on spatial dimensions - predictions = self.avg_layer(conv2, name='avg_pool') - - return predictions - - def pool_layer(self, inputs, size, stride, name): - with tf.variable_scope(name) as scope: - outputs = tf.nn.max_pool(inputs, - ksize=[1,size,size,1], - strides=[1,stride,stride,1], - padding='SAME', - name=name) - return outputs - - def fire_layer(self, inputs, s1x1, e1x1, e3x3, name, decay=False): - with tf.variable_scope(name) as scope: - # Squeeze sub-layer - squeezed_inputs = self.conv_layer(inputs, - size=1, - filters=s1x1, - stride=1, - decay=decay, - name='s1x1') - - # Expand 1x1 sub-layer - e1x1_outputs = self.conv_layer(squeezed_inputs, - size=1, - filters=e1x1, - stride=1, - decay=decay, - name='e1x1') - - # Expand 3x3 sub-layer - e3x3_outputs = self.conv_layer(squeezed_inputs, - size=3, - filters=e3x3, - stride=1, - decay=decay, - name='e3x3') - - # Concatenate outputs along the last dimension (channel) - return tf.concat([e1x1_outputs, e3x3_outputs], 3) - - def avg_layer(self, inputs, name): - w = inputs.get_shape().as_list()[1] - h = inputs.get_shape().as_list()[2] - c = inputs.get_shape().as_list()[3] - with tf.variable_scope(name) as scope: - # Use current spatial dimensions as Kernel size to produce a scalar - avg = tf.nn.avg_pool(inputs, - ksize=[1,w,h,1], - strides=[1,1,1,1], - padding='VALID', - name=scope.name) - # Reshape output to remove spatial dimensions reduced to one - return tf.reshape(avg, shape=[-1,c]) - - def conv_layer(self, inputs, size, filters, stride, decay, name): - channels = inputs.shape[3] - shape = [size, size, channels, filters] - with tf.variable_scope(name + '/conv') as scope: - weights = self._get_weights_var('weights', shape=shape, decay=decay) - biases = self.get_cons_variable([filters], 0.0) - conv = tf.nn.conv2d(inputs, - weights, - strides=[1,stride,stride,1], - padding='SAME') - - pre_activation = tf.nn.bias_add(conv, biases) - - outputs= tf.nn.relu(pre_activation, name=scope.name) - - return outputs - - def get_cons_variable(self, shape, val): - initial = tf.constant(val, shape=shape) - temp = tf.Variable(initial) - self.all_weights.append(temp) - return temp - - def _get_weights_var(self, name, shape, decay=False): - """Helper to create an initialized Variable with weight decay. - - The Variable is initialized using a normal distribution whose variance - is provided by the xavier formula (ie inversely proportional to the number - of inputs) - - Args: - name: name of the tensor variable - shape: the tensor shape - decay: a boolean indicating if we apply decay to the tensor weights - using a regularization loss - - Returns: - Variable Tensor - """ - # Declare an initializer for this variable - initializer = tf.contrib.layers.xavier_initializer(uniform=False,dtype=tf.float32) - # Declare variable (it is trainable by default) - var = tf.get_variable(name=name, - shape=shape, - initializer=initializer, - dtype=tf.float32) - if decay: - # We apply a weight decay to this tensor var that is equal to the - # model weight decay divided by the tensor size - weight_decay = self.wd - for x in shape: - weight_decay /= x - # Weight loss is L2 loss multiplied by weight decay - weight_loss = tf.multiply(tf.nn.l2_loss(var), - weight_decay, - name='weight_loss') - # Add weight loss for this variable to the global losses collection - tf.add_to_collection('losses', weight_loss) - - self.all_weights.append(var) - return var + def __init__(self): + self.all_weights = [] + + def inference(self, images): + # conv1 + conv1 = self.conv_layer( + images, size=3, filters=64, stride=1, decay=False, name="conv1" + ) + + # pool1 + pool1 = self.pool_layer(conv1, size=3, stride=2, name="pool1") + + # fire2 + fire2 = self.fire_layer(pool1, 32, 64, 64, decay=False, name="fire2") + + # fire3 + fire3 = self.fire_layer(fire2, 32, 64, 64, decay=False, name="fire3") + + # pool2 + pool2 = self.pool_layer(fire3, size=3, stride=2, name="pool2") + + # fire4 + fire4 = self.fire_layer(pool2, 32, 128, 128, decay=False, name="fire4") + + # fire5 + fire5 = self.fire_layer(fire4, 32, 128, 128, decay=False, name="fire5") + + # Final squeeze to get ten classes + conv2 = self.conv_layer( + fire5, size=1, filters=10, stride=1, decay=False, name="squeeze" + ) + + # Average pooling on spatial dimensions + predictions = self.avg_layer(conv2, name="avg_pool") + + return predictions + + def pool_layer(self, inputs, size, stride, name): + with tf.variable_scope(name) as scope: + outputs = tf.nn.max_pool( + inputs, + ksize=[1, size, size, 1], + strides=[1, stride, stride, 1], + padding="SAME", + name=name, + ) + return outputs + + def fire_layer(self, inputs, s1x1, e1x1, e3x3, name, decay=False): + with tf.variable_scope(name) as scope: + # Squeeze sub-layer + squeezed_inputs = self.conv_layer( + inputs, size=1, filters=s1x1, stride=1, decay=decay, name="s1x1" + ) + + # Expand 1x1 sub-layer + e1x1_outputs = self.conv_layer( + squeezed_inputs, + size=1, + filters=e1x1, + stride=1, + decay=decay, + name="e1x1", + ) + + # Expand 3x3 sub-layer + e3x3_outputs = self.conv_layer( + squeezed_inputs, + size=3, + filters=e3x3, + stride=1, + decay=decay, + name="e3x3", + ) + + # Concatenate outputs along the last dimension (channel) + return tf.concat([e1x1_outputs, e3x3_outputs], 3) + + def avg_layer(self, inputs, name): + w = inputs.get_shape().as_list()[1] + h = inputs.get_shape().as_list()[2] + c = inputs.get_shape().as_list()[3] + with tf.variable_scope(name) as scope: + # Use current spatial dimensions as Kernel size to produce a scalar + avg = tf.nn.avg_pool( + inputs, + ksize=[1, w, h, 1], + strides=[1, 1, 1, 1], + padding="VALID", + name=scope.name, + ) + # Reshape output to remove spatial dimensions reduced to one + return tf.reshape(avg, shape=[-1, c]) + + def conv_layer(self, inputs, size, filters, stride, decay, name): + channels = inputs.shape[3] + shape = [size, size, channels, filters] + with tf.variable_scope(name + "/conv") as scope: + weights = self._get_weights_var("weights", shape=shape, decay=decay) + biases = self.get_cons_variable([filters], 0.0) + conv = tf.nn.conv2d( + inputs, weights, strides=[1, stride, stride, 1], padding="SAME" + ) + + pre_activation = tf.nn.bias_add(conv, biases) + + outputs = tf.nn.relu(pre_activation, name=scope.name) + + return outputs + + def get_cons_variable(self, shape, val): + initial = tf.constant(val, shape=shape) + temp = tf.Variable(initial) + self.all_weights.append(temp) + return temp + + def _get_weights_var(self, name, shape, decay=False): + """Helper to create an initialized Variable with weight decay. + + The Variable is initialized using a normal distribution whose variance + is provided by the xavier formula (ie inversely proportional to the number + of inputs) + + Args: + name: name of the tensor variable + shape: the tensor shape + decay: a boolean indicating if we apply decay to the tensor weights + using a regularization loss + + Returns: + Variable Tensor + """ + # Declare an initializer for this variable + initializer = tf.contrib.layers.xavier_initializer( + uniform=False, dtype=tf.float32 + ) + # Declare variable (it is trainable by default) + var = tf.get_variable( + name=name, shape=shape, initializer=initializer, dtype=tf.float32 + ) + if decay: + # We apply a weight decay to this tensor var that is equal to the + # model weight decay divided by the tensor size + weight_decay = self.wd + for x in shape: + weight_decay /= x + # Weight loss is L2 loss multiplied by weight decay + weight_loss = tf.multiply( + tf.nn.l2_loss(var), weight_decay, name="weight_loss" + ) + # Add weight loss for this variable to the global losses collection + tf.add_to_collection("losses", weight_loss) + + self.all_weights.append(var) + return var + class SqueezeNet1: - def __init__(self, use_cons_init): - self.all_weights = [] - self.debug_weights = [] - self.use_cons_init = use_cons_init - - def inference(self, images): - # conv1 - conv1 = self.conv_layer(images, - size=3, - filters=64, - stride=1, - decay=False, - name='conv1') - - # pool1 - pool1 = self.pool_layer(conv1, - size=3, - stride=2, - name='pool1') - - # fire2 - fire2 = self.fire_layer(pool1, 32, 64, 64, decay=False, name='fire2') - - # fire3 - fire3 = self.fire_layer(fire2, 32, 64, 64, decay=False, name='fire3') - - # pool2 - pool2 = self.pool_layer(fire3, - size=3, - stride=2, - name='pool2') - - # fire4 - fire4 = self.fire_layer(pool2, 32, 128, 128, decay=False, name='fire4') - - # fire5 - fire5 = self.fire_layer(fire4, 32, 128, 128, decay=False, name='fire5') - - # Final squeeze to get ten classes - conv2 = self.conv_layer(fire5, - size=1, - filters=10, - stride=1, - decay=False, - name='squeeze') - - # Average pooling on spatial dimensions - predictions = self.avg_layer(conv2, name='avg_pool') - - return predictions - - def pool_layer(self, inputs, size, stride, name): - with tf.variable_scope(name) as scope: - outputs = tf.nn.max_pool(inputs, - ksize=[1,size,size,1], - strides=[1,stride,stride,1], - padding='SAME', - name=name) - return outputs - - def fire_layer(self, inputs, s1x1, e1x1, e3x3, name, decay=False): - with tf.variable_scope(name) as scope: - # Squeeze sub-layer - squeezed_inputs = self.conv_layer(inputs, - size=1, - filters=s1x1, - stride=1, - decay=decay, - name='s1x1') - - # Expand 1x1 sub-layer - e1x1_outputs = self.conv_layer(squeezed_inputs, - size=1, - filters=e1x1, - stride=1, - decay=decay, - name='e1x1') - - # Expand 3x3 sub-layer - e3x3_outputs = self.conv_layer(squeezed_inputs, - size=3, - filters=e3x3, - stride=1, - decay=decay, - name='e3x3') - - # Concatenate outputs along the last dimension (channel) - return tf.concat([e1x1_outputs, e3x3_outputs], 3) - - def avg_layer(self, inputs, name): - w = inputs.get_shape().as_list()[1] - h = inputs.get_shape().as_list()[2] - c = inputs.get_shape().as_list()[3] - with tf.variable_scope(name) as scope: - # Use current spatial dimensions as Kernel size to produce a scalar - avg = tf.nn.avg_pool(inputs, - ksize=[1,w,h,1], - strides=[1,1,1,1], - padding='VALID', - name=scope.name) - # Reshape output to remove spatial dimensions reduced to one - return tf.reshape(avg, shape=[-1,c]) - - def conv_layer(self, inputs, size, filters, stride, decay, name): - channels = inputs.shape[3] - shape = [size, size, channels, filters] - with tf.variable_scope(name + '/conv') as scope: - # For getting performance numbers, don't need to use the actual activations - just use constant activations - if self.use_cons_init: - weights = self.get_cons_variable(shape, 0.01) - else: - weights = self._get_weights_var('weights', shape=shape, decay=decay) - - biases = self.get_cons_variable([filters], 0.0) - conv = tf.nn.conv2d(inputs, - weights, - strides=[1,stride,stride,1], - padding='SAME') - - pre_activation = tf.nn.bias_add(conv, biases) - outputs= tf.nn.relu(pre_activation, name=scope.name) - - return outputs - - def get_cons_variable(self, shape, val): - initial = tf.constant(val, shape=shape) - temp = tf.Variable(initial) - self.all_weights.append(temp) - return temp - - def _get_weights_var(self, name, shape, decay=False): - """Helper to create an initialized Variable with weight decay. - - The Variable is initialized using a normal distribution whose variance - is provided by the xavier formula (ie inversely proportional to the number - of inputs) - - Args: - name: name of the tensor variable - shape: the tensor shape - decay: a boolean indicating if we apply decay to the tensor weights - using a regularization loss - - Returns: - Variable Tensor - """ - # Declare an initializer for this variable - initializer = tf.contrib.layers.xavier_initializer(uniform=False,dtype=tf.float32) - # Declare variable (it is trainable by default) - var = tf.get_variable(name=name, - shape=shape, - initializer=initializer, - dtype=tf.float32) - if decay: - # We apply a weight decay to this tensor var that is equal to the - # model weight decay divided by the tensor size - weight_decay = self.wd - for x in shape: - weight_decay /= x - # Weight loss is L2 loss multiplied by weight decay - weight_loss = tf.multiply(tf.nn.l2_loss(var), - weight_decay, - name='weight_loss') - # Add weight loss for this variable to the global losses collection - tf.add_to_collection('losses', weight_loss) - - self.all_weights.append(var) - return var + def __init__(self, use_cons_init): + self.all_weights = [] + self.debug_weights = [] + self.use_cons_init = use_cons_init + + def inference(self, images): + # conv1 + conv1 = self.conv_layer( + images, size=3, filters=64, stride=1, decay=False, name="conv1" + ) + + # pool1 + pool1 = self.pool_layer(conv1, size=3, stride=2, name="pool1") + + # fire2 + fire2 = self.fire_layer(pool1, 32, 64, 64, decay=False, name="fire2") + + # fire3 + fire3 = self.fire_layer(fire2, 32, 64, 64, decay=False, name="fire3") + + # pool2 + pool2 = self.pool_layer(fire3, size=3, stride=2, name="pool2") + + # fire4 + fire4 = self.fire_layer(pool2, 32, 128, 128, decay=False, name="fire4") + + # fire5 + fire5 = self.fire_layer(fire4, 32, 128, 128, decay=False, name="fire5") + + # Final squeeze to get ten classes + conv2 = self.conv_layer( + fire5, size=1, filters=10, stride=1, decay=False, name="squeeze" + ) + + # Average pooling on spatial dimensions + predictions = self.avg_layer(conv2, name="avg_pool") + + return predictions + + def pool_layer(self, inputs, size, stride, name): + with tf.variable_scope(name) as scope: + outputs = tf.nn.max_pool( + inputs, + ksize=[1, size, size, 1], + strides=[1, stride, stride, 1], + padding="SAME", + name=name, + ) + return outputs + + def fire_layer(self, inputs, s1x1, e1x1, e3x3, name, decay=False): + with tf.variable_scope(name) as scope: + # Squeeze sub-layer + squeezed_inputs = self.conv_layer( + inputs, size=1, filters=s1x1, stride=1, decay=decay, name="s1x1" + ) + + # Expand 1x1 sub-layer + e1x1_outputs = self.conv_layer( + squeezed_inputs, + size=1, + filters=e1x1, + stride=1, + decay=decay, + name="e1x1", + ) + + # Expand 3x3 sub-layer + e3x3_outputs = self.conv_layer( + squeezed_inputs, + size=3, + filters=e3x3, + stride=1, + decay=decay, + name="e3x3", + ) + + # Concatenate outputs along the last dimension (channel) + return tf.concat([e1x1_outputs, e3x3_outputs], 3) + + def avg_layer(self, inputs, name): + w = inputs.get_shape().as_list()[1] + h = inputs.get_shape().as_list()[2] + c = inputs.get_shape().as_list()[3] + with tf.variable_scope(name) as scope: + # Use current spatial dimensions as Kernel size to produce a scalar + avg = tf.nn.avg_pool( + inputs, + ksize=[1, w, h, 1], + strides=[1, 1, 1, 1], + padding="VALID", + name=scope.name, + ) + # Reshape output to remove spatial dimensions reduced to one + return tf.reshape(avg, shape=[-1, c]) + + def conv_layer(self, inputs, size, filters, stride, decay, name): + channels = inputs.shape[3] + shape = [size, size, channels, filters] + with tf.variable_scope(name + "/conv") as scope: + # For getting performance numbers, don't need to use the actual activations - just use constant activations + if self.use_cons_init: + weights = self.get_cons_variable(shape, 0.01) + else: + weights = self._get_weights_var("weights", shape=shape, decay=decay) + + biases = self.get_cons_variable([filters], 0.0) + conv = tf.nn.conv2d( + inputs, weights, strides=[1, stride, stride, 1], padding="SAME" + ) + + pre_activation = tf.nn.bias_add(conv, biases) + outputs = tf.nn.relu(pre_activation, name=scope.name) + + return outputs + + def get_cons_variable(self, shape, val): + initial = tf.constant(val, shape=shape) + temp = tf.Variable(initial) + self.all_weights.append(temp) + return temp + + def _get_weights_var(self, name, shape, decay=False): + """Helper to create an initialized Variable with weight decay. + + The Variable is initialized using a normal distribution whose variance + is provided by the xavier formula (ie inversely proportional to the number + of inputs) + + Args: + name: name of the tensor variable + shape: the tensor shape + decay: a boolean indicating if we apply decay to the tensor weights + using a regularization loss + + Returns: + Variable Tensor + """ + # Declare an initializer for this variable + initializer = tf.contrib.layers.xavier_initializer( + uniform=False, dtype=tf.float32 + ) + # Declare variable (it is trainable by default) + var = tf.get_variable( + name=name, shape=shape, initializer=initializer, dtype=tf.float32 + ) + if decay: + # We apply a weight decay to this tensor var that is equal to the + # model weight decay divided by the tensor size + weight_decay = self.wd + for x in shape: + weight_decay /= x + # Weight loss is L2 loss multiplied by weight decay + weight_loss = tf.multiply( + tf.nn.l2_loss(var), weight_decay, name="weight_loss" + ) + # Add weight loss for this variable to the global losses collection + tf.add_to_collection("losses", weight_loss) + + self.all_weights.append(var) + return var + def train(sqn, save_model_path): - print('Starting train...') - - # Hyper parameters - epochs = 1 - batch_size = 128 - keep_probability = 0.7 - learning_rate = 0.001 - n_batches = 5 #CIFAR10 dataset in the python version has 5 batches - - x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='input_x') - y = tf.placeholder(tf.float32, shape=(None, 10), name='output_y') - keep_prob = tf.placeholder(tf.float32, name='keep_prob') - - logits = sqn.inference(x) - - # Loss and Optimizer - cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)) - l2Loss = sum(list(map(lambda x: tf.nn.l2_loss(x), sqn.all_weights))) - beta = 1e-5 - cost = cost + beta*l2Loss - optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) - - # Accuracy - correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1)) - accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') - - valid_features, valid_labels = Util.load_preprocess_validation_data() - testing_features, testing_labels = Util.load_preprocess_testing_data() - - print('Training now...') - with tf.Session() as sess: - # Initializing the variables - sess.run(tf.global_variables_initializer()) - - # Training cycle - for epoch in range(epochs): - # Loop over all batches - for batch_i in range(1, n_batches + 1): - for batch_features, batch_labels in Util.load_preprocess_training_batch(batch_i, batch_size): - sess.run(optimizer, feed_dict={x: batch_features, - y: batch_labels, - keep_prob: keep_probability - }) - - print('Epoch {:>2}, CIFAR-10 Batch {}: '.format(epoch + 1, batch_i), end='') - - # Print stats - loss = sess.run(cost, feed_dict={x: batch_features, y: batch_labels, keep_prob: keep_probability}) - train_acc = sess.run(accuracy, feed_dict={x: batch_features, y: batch_labels, keep_prob: keep_probability}) - valid_acc = sess.run(accuracy, feed_dict={x: valid_features, y: valid_labels, keep_prob: keep_probability}) - testing_acc = sess.run(accuracy, feed_dict={x: testing_features, y: testing_labels, keep_prob: keep_probability}) - print('Loss: {:>10.4f} Train Acc: {:.6f} Validation Accuracy: {:.6f} Testing Acc: {:.6f}'.format(loss, train_acc, valid_acc, testing_acc)) - - if (epoch % 10 == 0): - # Save Model - saver = tf.train.Saver() - save_path = saver.save(sess, save_model_path) - -#outputArgMax should only be used when findAcc is False -def infer(sqn, sess, images, labels, restoreModelPath, findAccOrArgMaxOrPredVal=0, restoreWeights=True, onlysavegraph=False): - assert(findAccOrArgMaxOrPredVal>=0 and findAccOrArgMaxOrPredVal<=2) - if restoreWeights: assert(not(onlysavegraph)) - if onlysavegraph: assert(findAccOrArgMaxOrPredVal==1) - - x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='input_x') - if (not(onlysavegraph)): - y = tf.placeholder(tf.int32, shape=(None, 10), name='output_y') - logits = sqn.inference(x) - - if findAccOrArgMaxOrPredVal==0: - correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1)) - accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') - elif findAccOrArgMaxOrPredVal==1: - logits = tf.argmax(logits, axis=1) - elif findAccOrArgMaxOrPredVal==2: - pass - else: - assert False - - print("Doing inference on ", len(images), " images.") - feed_dict = {x: images} - if not(onlysavegraph): - feed_dict[y] = labels - - sess.run(tf.global_variables_initializer()) - if onlysavegraph: - output_tensor = None - gg = tf.get_default_graph() - for node in gg.as_graph_def().node: - if node.name == 'ArgMax': - output_tensor = gg.get_operation_by_name(node.name).outputs[0] - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) - return - - if restoreWeights: - saver = tf.train.Saver(sqn.all_weights) - saver.restore(sess, restoreModelPath) - - print("*************** Starting Prediction****************") - start_time = time.time() - if findAccOrArgMaxOrPredVal==0: - predictions = sess.run([accuracy], feed_dict=feed_dict) - else: - predictions = sess.run([logits], feed_dict=feed_dict) - end_time = time.time() - print("*************** Done Prediction****************") - duration = end_time - start_time - print("Time taken in prediction : ", duration) - - print("Inference result = ", predictions) - with open('tf_pred.float','w+') as f: - f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) - with open('tf_pred.time','w') as f: - f.write(str(round(duration, 2))) - return predictions + print("Starting train...") + + # Hyper parameters + epochs = 1 + batch_size = 128 + keep_probability = 0.7 + learning_rate = 0.001 + n_batches = 5 # CIFAR10 dataset in the python version has 5 batches + + x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name="input_x") + y = tf.placeholder(tf.float32, shape=(None, 10), name="output_y") + keep_prob = tf.placeholder(tf.float32, name="keep_prob") + + logits = sqn.inference(x) + + # Loss and Optimizer + cost = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y) + ) + l2Loss = sum(list(map(lambda x: tf.nn.l2_loss(x), sqn.all_weights))) + beta = 1e-5 + cost = cost + beta * l2Loss + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) + + # Accuracy + correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="accuracy") + + valid_features, valid_labels = Util.load_preprocess_validation_data() + testing_features, testing_labels = Util.load_preprocess_testing_data() + + print("Training now...") + with tf.Session() as sess: + # Initializing the variables + sess.run(tf.global_variables_initializer()) + + # Training cycle + for epoch in range(epochs): + # Loop over all batches + for batch_i in range(1, n_batches + 1): + for batch_features, batch_labels in Util.load_preprocess_training_batch( + batch_i, batch_size + ): + sess.run( + optimizer, + feed_dict={ + x: batch_features, + y: batch_labels, + keep_prob: keep_probability, + }, + ) + + print( + "Epoch {:>2}, CIFAR-10 Batch {}: ".format(epoch + 1, batch_i), + end="", + ) + + # Print stats + loss = sess.run( + cost, + feed_dict={ + x: batch_features, + y: batch_labels, + keep_prob: keep_probability, + }, + ) + train_acc = sess.run( + accuracy, + feed_dict={ + x: batch_features, + y: batch_labels, + keep_prob: keep_probability, + }, + ) + valid_acc = sess.run( + accuracy, + feed_dict={ + x: valid_features, + y: valid_labels, + keep_prob: keep_probability, + }, + ) + testing_acc = sess.run( + accuracy, + feed_dict={ + x: testing_features, + y: testing_labels, + keep_prob: keep_probability, + }, + ) + print( + "Loss: {:>10.4f} Train Acc: {:.6f} Validation Accuracy: {:.6f} Testing Acc: {:.6f}".format( + loss, train_acc, valid_acc, testing_acc + ) + ) + + if epoch % 10 == 0: + # Save Model + saver = tf.train.Saver() + save_path = saver.save(sess, save_model_path) + + +# outputArgMax should only be used when findAcc is False +def infer( + sqn, + sess, + images, + labels, + restoreModelPath, + findAccOrArgMaxOrPredVal=0, + restoreWeights=True, + onlysavegraph=False, +): + assert findAccOrArgMaxOrPredVal >= 0 and findAccOrArgMaxOrPredVal <= 2 + if restoreWeights: + assert not (onlysavegraph) + if onlysavegraph: + assert findAccOrArgMaxOrPredVal == 1 + + x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name="input_x") + if not (onlysavegraph): + y = tf.placeholder(tf.int32, shape=(None, 10), name="output_y") + logits = sqn.inference(x) + + if findAccOrArgMaxOrPredVal == 0: + correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name="accuracy") + elif findAccOrArgMaxOrPredVal == 1: + logits = tf.argmax(logits, axis=1) + elif findAccOrArgMaxOrPredVal == 2: + pass + else: + assert False + + print("Doing inference on ", len(images), " images.") + feed_dict = {x: images} + if not (onlysavegraph): + feed_dict[y] = labels + + sess.run(tf.global_variables_initializer()) + if onlysavegraph: + output_tensor = None + gg = tf.get_default_graph() + for node in gg.as_graph_def().node: + if node.name == "ArgMax": + output_tensor = gg.get_operation_by_name(node.name).outputs[0] + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) + return + + if restoreWeights: + saver = tf.train.Saver(sqn.all_weights) + saver.restore(sess, restoreModelPath) + + print("*************** Starting Prediction****************") + start_time = time.time() + if findAccOrArgMaxOrPredVal == 0: + predictions = sess.run([accuracy], feed_dict=feed_dict) + else: + predictions = sess.run([logits], feed_dict=feed_dict) + end_time = time.time() + print("*************** Done Prediction****************") + duration = end_time - start_time + print("Time taken in prediction : ", duration) + + print("Inference result = ", predictions) + with open("tf_pred.float", "w+") as f: + f.write(DumpTFMtData.numpy_float_array_to_float_val_str(predictions)) + with open("tf_pred.time", "w") as f: + f.write(str(round(duration, 2))) + return predictions + def getTrainedWeightsStrForm(sess, evalTensors, scalingFac): - allWeightsStr = '' - finalParameters = map(lambda x : sess.run(x), evalTensors) - for curParameterVal in finalParameters: - for xx in numpy.nditer(curParameterVal, order='C'): - allWeightsStr += (str(int(xx * (1< 1): - inp = sys.argv[1] - if (inp == 'train'): - doTraining = True - elif (inp == 'savegraph'): - findAccOrArgMaxOrPredVal = 1 - restoreWeights = False - onlysavegraph = True - testing_features, testing_labels = Util.get_sample_points(2, 4555, 4556) - elif (inp == 'testSingleTestInp'): - testBatchInpNum = int(sys.argv[2]) - findAccOrArgMaxOrPredVal = 1 - restoreWeights = True - onlysavegraph = False - all_testing_features, all_testing_labels = Util.load_preprocess_testing_data() - testing_features, testing_labels = all_testing_features[testBatchInpNum:testBatchInpNum+1], all_testing_labels[testBatchInpNum:testBatchInpNum+1] - elif (inp == 'testSingleTestInpAndSaveData'): - testBatchInpNum = int(sys.argv[2]) - findAccOrArgMaxOrPredVal = int(sys.argv[3]) - restoreWeights = True - onlysavegraph = False - all_testing_features, all_testing_labels = Util.load_preprocess_testing_data() - testing_features, testing_labels = all_testing_features[testBatchInpNum:testBatchInpNum+1], all_testing_labels[testBatchInpNum:testBatchInpNum+1] - # testing_features, testing_labels = numpy.full((1,32,32,3),0.01), numpy.full((1,10),0.01) - elif (inp == 'savegraphAndDataBatch'): - batchNum = int(sys.argv[2]) - imgStartNum = int(sys.argv[3]) - imgEndNum = int(sys.argv[4]) - findAccOrArgMaxOrPredVal = 1 - restoreWeights = False - onlysavegraph = True - testing_features, testing_labels = Util.get_sample_points(batchNum, imgStartNum, imgEndNum) - elif (inp == 'testBatchInp'): - imgStartNum = int(sys.argv[2]) - imgEndNum = int(sys.argv[3]) - findAccOrArgMaxOrPredVal = 1 - restoreWeights = True - onlysavegraph = False - all_testing_features, all_testing_labels = Util.load_preprocess_testing_data() - testing_features, testing_labels = all_testing_features[imgStartNum:imgEndNum], all_testing_labels[imgStartNum:imgEndNum] - elif (inp == 'findAndSaveCorrectTestImg'): - findAccOrArgMaxOrPredVal = 2 - restoreWeights = True - onlysavegraph = False - testing_features, testing_labels = Util.load_preprocess_testing_data() - testing_features, testing_labels = testing_features[0:100], testing_labels[0:100] - else: - if (inp != ""): - print("WARNING : Given option didn't match any known value.") - testing_features, testing_labels = Util.load_preprocess_testing_data() - - sqn = SqueezeNet1(use_cons_init=onlysavegraph) - if doTraining: - train(sqn, save_model_path) - else: - with tf.Session() as sess: - pred = infer(sqn, sess, testing_features, testing_labels, save_model_path, findAccOrArgMaxOrPredVal=findAccOrArgMaxOrPredVal, - restoreWeights=restoreWeights, - onlysavegraph=onlysavegraph) - if findAccOrArgMaxOrPredVal==1 and not(onlysavegraph): - print("Actual labels = ", testing_labels) - print("ArgMax in actual label : ", numpy.argmax(testing_labels, axis=1)) - - if (inp == 'findAndSaveCorrectTestImg'): - print('Running ' + inp) - print(pred[0].shape) - findAndSaveCorrectTestImg(pred, testing_features, testing_labels, './testPred/CorrectImg/', './testPred/IncorrectImg/', './testPred/TestInputs/', sess, sqn, scalingFac) - - if (inp == 'savegraphAndDataBatch' or inp=='testSingleTestInpAndSaveData'): - imgFileName = "model_input_scale_{}.inp".format(scalingFac) - weightsFileName = "model_weights_scale_{}.inp".format(scalingFac) - for ii,curFeature in enumerate(testing_features): - if ii == 0 : - DumpTFMtData.dumpImageDataInt(curFeature, imgFileName, scalingFac, 'w') - else: - DumpTFMtData.dumpImageDataInt(curFeature, imgFileName, scalingFac, 'a') - DumpTFMtData.dumpTrainedWeightsInt(sess, sqn.all_weights, weightsFileName, scalingFac, 'w') - -if __name__ == '__main__': - main() - + scalingFac = 12 + findAccOrArgMaxOrPredVal = 0 + restoreWeights = True + onlysavegraph = False + save_model_path = "./TrainedModel/model" + doTraining = False + + inp = None + if len(sys.argv) > 1: + inp = sys.argv[1] + if inp == "train": + doTraining = True + elif inp == "savegraph": + findAccOrArgMaxOrPredVal = 1 + restoreWeights = False + onlysavegraph = True + testing_features, testing_labels = Util.get_sample_points(2, 4555, 4556) + elif inp == "testSingleTestInp": + testBatchInpNum = int(sys.argv[2]) + findAccOrArgMaxOrPredVal = 1 + restoreWeights = True + onlysavegraph = False + ( + all_testing_features, + all_testing_labels, + ) = Util.load_preprocess_testing_data() + testing_features, testing_labels = ( + all_testing_features[testBatchInpNum : testBatchInpNum + 1], + all_testing_labels[testBatchInpNum : testBatchInpNum + 1], + ) + elif inp == "testSingleTestInpAndSaveData": + testBatchInpNum = int(sys.argv[2]) + findAccOrArgMaxOrPredVal = int(sys.argv[3]) + restoreWeights = True + onlysavegraph = False + ( + all_testing_features, + all_testing_labels, + ) = Util.load_preprocess_testing_data() + testing_features, testing_labels = ( + all_testing_features[testBatchInpNum : testBatchInpNum + 1], + all_testing_labels[testBatchInpNum : testBatchInpNum + 1], + ) + # testing_features, testing_labels = numpy.full((1,32,32,3),0.01), numpy.full((1,10),0.01) + elif inp == "savegraphAndDataBatch": + batchNum = int(sys.argv[2]) + imgStartNum = int(sys.argv[3]) + imgEndNum = int(sys.argv[4]) + findAccOrArgMaxOrPredVal = 1 + restoreWeights = False + onlysavegraph = True + testing_features, testing_labels = Util.get_sample_points( + batchNum, imgStartNum, imgEndNum + ) + elif inp == "testBatchInp": + imgStartNum = int(sys.argv[2]) + imgEndNum = int(sys.argv[3]) + findAccOrArgMaxOrPredVal = 1 + restoreWeights = True + onlysavegraph = False + ( + all_testing_features, + all_testing_labels, + ) = Util.load_preprocess_testing_data() + testing_features, testing_labels = ( + all_testing_features[imgStartNum:imgEndNum], + all_testing_labels[imgStartNum:imgEndNum], + ) + elif inp == "findAndSaveCorrectTestImg": + findAccOrArgMaxOrPredVal = 2 + restoreWeights = True + onlysavegraph = False + testing_features, testing_labels = Util.load_preprocess_testing_data() + testing_features, testing_labels = ( + testing_features[0:100], + testing_labels[0:100], + ) + else: + if inp != "": + print("WARNING : Given option didn't match any known value.") + testing_features, testing_labels = Util.load_preprocess_testing_data() + + sqn = SqueezeNet1(use_cons_init=onlysavegraph) + if doTraining: + train(sqn, save_model_path) + else: + with tf.Session() as sess: + pred = infer( + sqn, + sess, + testing_features, + testing_labels, + save_model_path, + findAccOrArgMaxOrPredVal=findAccOrArgMaxOrPredVal, + restoreWeights=restoreWeights, + onlysavegraph=onlysavegraph, + ) + if findAccOrArgMaxOrPredVal == 1 and not (onlysavegraph): + print("Actual labels = ", testing_labels) + print("ArgMax in actual label : ", numpy.argmax(testing_labels, axis=1)) + + if inp == "findAndSaveCorrectTestImg": + print("Running " + inp) + print(pred[0].shape) + findAndSaveCorrectTestImg( + pred, + testing_features, + testing_labels, + "./testPred/CorrectImg/", + "./testPred/IncorrectImg/", + "./testPred/TestInputs/", + sess, + sqn, + scalingFac, + ) + + if inp == "savegraphAndDataBatch" or inp == "testSingleTestInpAndSaveData": + imgFileName = "model_input_scale_{}.inp".format(scalingFac) + weightsFileName = "model_weights_scale_{}.inp".format(scalingFac) + for ii, curFeature in enumerate(testing_features): + if ii == 0: + DumpTFMtData.dumpImageDataInt( + curFeature, imgFileName, scalingFac, "w" + ) + else: + DumpTFMtData.dumpImageDataInt( + curFeature, imgFileName, scalingFac, "a" + ) + DumpTFMtData.dumpTrainedWeightsInt( + sess, sqn.all_weights, weightsFileName, scalingFac, "w" + ) + + +if __name__ == "__main__": + main() diff --git a/Athos/Networks/SqueezeNetCIFAR10/Util.py b/Athos/Networks/SqueezeNetCIFAR10/Util.py index bc368670..51f2cbe4 100644 --- a/Athos/Networks/SqueezeNetCIFAR10/Util.py +++ b/Athos/Networks/SqueezeNetCIFAR10/Util.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -25,7 +25,7 @@ Modified for our purposes. ** -''' +""" import os, sys import pickle @@ -34,70 +34,104 @@ import tensorflow as tf from matplotlib import pyplot as plt -preProcessedImgSaveFolderConst = './PreProcessedImages' +preProcessedImgSaveFolderConst = "./PreProcessedImages" + def load_label_names(): - return ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] + return [ + "airplane", + "automobile", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", + ] + def load_cfar10_batch(cifar10_dataset_folder_path, batch_id): - with open(cifar10_dataset_folder_path + '/data_batch_' + str(batch_id), mode='rb') as file: + with open( + cifar10_dataset_folder_path + "/data_batch_" + str(batch_id), mode="rb" + ) as file: # note the encoding type is 'latin1' - batch = pickle.load(file, encoding='latin1') + batch = pickle.load(file, encoding="latin1") - features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1) - labels = batch['labels'] + features = ( + batch["data"].reshape((len(batch["data"]), 3, 32, 32)).transpose(0, 2, 3, 1) + ) + labels = batch["labels"] return features, labels -def display_stats(cifar10_dataset_folder_path, batch_id, sample_id, savepng=False, showfig=False): + +def display_stats( + cifar10_dataset_folder_path, batch_id, sample_id, savepng=False, showfig=False +): features, labels = load_cfar10_batch(cifar10_dataset_folder_path, batch_id) if not (0 <= sample_id < len(features)): - print('{} samples in batch {}. {} is out of range.'.format(len(features), batch_id, sample_id)) + print( + "{} samples in batch {}. {} is out of range.".format( + len(features), batch_id, sample_id + ) + ) return None - print('\nStats of batch #{}:'.format(batch_id)) - print('# of Samples: {}\n'.format(len(features))) + print("\nStats of batch #{}:".format(batch_id)) + print("# of Samples: {}\n".format(len(features))) label_names = load_label_names() label_counts = dict(zip(*np.unique(labels, return_counts=True))) for key, value in label_counts.items(): - print('Label Counts of [{}]({}) : {}'.format(key, label_names[key].upper(), value)) + print( + "Label Counts of [{}]({}) : {}".format(key, label_names[key].upper(), value) + ) sample_image = features[sample_id] sample_label = labels[sample_id] - print('\nExample of Image {}:'.format(sample_id)) - print('Image - Min Value: {} Max Value: {}'.format(sample_image.min(), sample_image.max())) - print('Image - Shape: {}'.format(sample_image.shape)) - print('Label - Label Id: {} Name: {}'.format(sample_label, label_names[sample_label])) + print("\nExample of Image {}:".format(sample_id)) + print( + "Image - Min Value: {} Max Value: {}".format( + sample_image.min(), sample_image.max() + ) + ) + print("Image - Shape: {}".format(sample_image.shape)) + print( + "Label - Label Id: {} Name: {}".format(sample_label, label_names[sample_label]) + ) if savepng or showfig: # Save/show a .png file for the current image plt.imshow(sample_image) if savepng: - plt.savefig('foo.png') + plt.savefig("foo.png") elif showfig: plt.show() + def normalize(x): """ - argument - - x: input image data in numpy array [32, 32, 3] - return - - normalized x + argument + - x: input image data in numpy array [32, 32, 3] + return + - normalized x """ min_val = np.min(x) max_val = np.max(x) - x = (x-min_val) / (max_val-min_val) + x = (x - min_val) / (max_val - min_val) return x + def one_hot_encode(x): """ - argument - - x: a list of labels - return - - one hot encoding matrix (number of labels, number of class) + argument + - x: a list of labels + return + - one hot encoding matrix (number of labels, number of class) """ encoded = np.zeros((len(x), 10)) @@ -106,16 +140,23 @@ def one_hot_encode(x): return encoded + def _preprocess_and_save(normalize, one_hot_encode, features, labels, filename): features = normalize(features) labels = one_hot_encode(labels) - pickle.dump((features, labels), open(filename, 'wb')) + pickle.dump((features, labels), open(filename, "wb")) + # Saved files are 'preprocess_batch_' + str(batch_i) + '.p', # 'preprocess_validation.p', # 'preprocess_testing.p' -def preprocess_and_save_data(cifar10_dataset_folder_path, normalize, one_hot_encode, preProcessedImgSaveFolder = preProcessedImgSaveFolderConst): +def preprocess_and_save_data( + cifar10_dataset_folder_path, + normalize, + one_hot_encode, + preProcessedImgSaveFolder=preProcessedImgSaveFolderConst, +): n_batches = 5 valid_features = [] valid_labels = [] @@ -131,9 +172,15 @@ def preprocess_and_save_data(cifar10_dataset_folder_path, normalize, one_hot_enc # - one_hot_encode the lables # - save in a new file named, "preprocess_batch_" + batch_number # - each file for each batch - _preprocess_and_save(normalize, one_hot_encode, - features[:-index_of_validation], labels[:-index_of_validation], - os.path.join(preProcessedImgSaveFolder, 'preprocess_batch_' + str(batch_i) + '.p')) + _preprocess_and_save( + normalize, + one_hot_encode, + features[:-index_of_validation], + labels[:-index_of_validation], + os.path.join( + preProcessedImgSaveFolder, "preprocess_batch_" + str(batch_i) + ".p" + ), + ) # unlike the training dataset, validation dataset will be added through all batch dataset # - take 10% of the whold dataset of the batch @@ -144,31 +191,62 @@ def preprocess_and_save_data(cifar10_dataset_folder_path, normalize, one_hot_enc valid_labels.extend(labels[-index_of_validation:]) # preprocess the all stacked validation dataset - _preprocess_and_save(normalize, one_hot_encode, - np.array(valid_features), np.array(valid_labels), - os.path.join(preProcessedImgSaveFolder, 'preprocess_validation.p')) + _preprocess_and_save( + normalize, + one_hot_encode, + np.array(valid_features), + np.array(valid_labels), + os.path.join(preProcessedImgSaveFolder, "preprocess_validation.p"), + ) # load the test dataset - with open(cifar10_dataset_folder_path + '/test_batch', mode='rb') as file: - batch = pickle.load(file, encoding='latin1') + with open(cifar10_dataset_folder_path + "/test_batch", mode="rb") as file: + batch = pickle.load(file, encoding="latin1") # preprocess the testing data - test_features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1) - test_labels = batch['labels'] + test_features = ( + batch["data"].reshape((len(batch["data"]), 3, 32, 32)).transpose(0, 2, 3, 1) + ) + test_labels = batch["labels"] # Preprocess and Save all testing data - _preprocess_and_save(normalize, one_hot_encode, - np.array(test_features), np.array(test_labels), - os.path.join(preProcessedImgSaveFolder, 'preprocess_testing.p')) - -def get_one_sample_point(batch_id, sample_id, preProcessedImgSaveFolder = preProcessedImgSaveFolderConst): - features, labels = pickle.load(open(os.path.join(preProcessedImgSaveFolder, 'preprocess_batch_' + str(batch_id) + '.p'), mode='rb')) + _preprocess_and_save( + normalize, + one_hot_encode, + np.array(test_features), + np.array(test_labels), + os.path.join(preProcessedImgSaveFolder, "preprocess_testing.p"), + ) + + +def get_one_sample_point( + batch_id, sample_id, preProcessedImgSaveFolder=preProcessedImgSaveFolderConst +): + features, labels = pickle.load( + open( + os.path.join( + preProcessedImgSaveFolder, "preprocess_batch_" + str(batch_id) + ".p" + ), + mode="rb", + ) + ) return (features[sample_id], labels[sample_id]) -def get_sample_points(batch_id, start_id, end_id, preProcessedImgSaveFolder = preProcessedImgSaveFolderConst): - features, labels = pickle.load(open(os.path.join(preProcessedImgSaveFolder, 'preprocess_batch_' + str(batch_id) + '.p'), mode='rb')) + +def get_sample_points( + batch_id, start_id, end_id, preProcessedImgSaveFolder=preProcessedImgSaveFolderConst +): + features, labels = pickle.load( + open( + os.path.join( + preProcessedImgSaveFolder, "preprocess_batch_" + str(batch_id) + ".p" + ), + mode="rb", + ) + ) return (features[start_id:end_id], labels[start_id:end_id]) + def batch_features_labels(features, labels, batch_size): """ Split features and labels into batches @@ -177,35 +255,60 @@ def batch_features_labels(features, labels, batch_size): end = min(start + batch_size, len(features)) yield features[start:end], labels[start:end] -def load_preprocess_training_batch(batch_id, batch_size, preProcessedImgSaveFolder = preProcessedImgSaveFolderConst): + +def load_preprocess_training_batch( + batch_id, batch_size, preProcessedImgSaveFolder=preProcessedImgSaveFolderConst +): """ Load the Preprocessed Training data and return them in batches of or less """ - filename = os.path.join(preProcessedImgSaveFolder, 'preprocess_batch_' + str(batch_id) + '.p') - features, labels = pickle.load(open(filename, mode='rb')) + filename = os.path.join( + preProcessedImgSaveFolder, "preprocess_batch_" + str(batch_id) + ".p" + ) + features, labels = pickle.load(open(filename, mode="rb")) # Return the training data in batches of size or less return batch_features_labels(features, labels, batch_size) -def load_preprocess_training_data(batch_id, preProcessedImgSaveFolder = preProcessedImgSaveFolderConst): - filename = os.path.join(preProcessedImgSaveFolder, 'preprocess_batch_' + str(batch_id) + '.p') - features, labels = pickle.load(open(filename, mode='rb')) + +def load_preprocess_training_data( + batch_id, preProcessedImgSaveFolder=preProcessedImgSaveFolderConst +): + filename = os.path.join( + preProcessedImgSaveFolder, "preprocess_batch_" + str(batch_id) + ".p" + ) + features, labels = pickle.load(open(filename, mode="rb")) return features, labels -def load_preprocess_validation_data(preProcessedImgSaveFolder = preProcessedImgSaveFolderConst): - valid_features, valid_labels = pickle.load(open(os.path.join(preProcessedImgSaveFolder, 'preprocess_validation.p'), mode='rb')) + +def load_preprocess_validation_data( + preProcessedImgSaveFolder=preProcessedImgSaveFolderConst, +): + valid_features, valid_labels = pickle.load( + open( + os.path.join(preProcessedImgSaveFolder, "preprocess_validation.p"), + mode="rb", + ) + ) return valid_features, valid_labels -def load_preprocess_testing_data(preProcessedImgSaveFolder = preProcessedImgSaveFolderConst): - testing_features, testing_labels = pickle.load(open(os.path.join(preProcessedImgSaveFolder, 'preprocess_testing.p'), mode='rb')) + +def load_preprocess_testing_data( + preProcessedImgSaveFolder=preProcessedImgSaveFolderConst, +): + testing_features, testing_labels = pickle.load( + open(os.path.join(preProcessedImgSaveFolder, "preprocess_testing.p"), mode="rb") + ) return testing_features, testing_labels + def main(): - cifar10_dataset_folder_path = '../../HelperScripts/CIFAR10/cifar-10-batches-py' - preProcessedImgSaveFolder = './PreProcessedImages' + cifar10_dataset_folder_path = "../../HelperScripts/CIFAR10/cifar-10-batches-py" + preProcessedImgSaveFolder = "./PreProcessedImages" preprocess_and_save_data(cifar10_dataset_folder_path, normalize, one_hot_encode) display_stats(cifar10_dataset_folder_path, 2, 4555) print(get_one_sample_point(2, 4555)) -if __name__ == '__main__': - main() + +if __name__ == "__main__": + main() diff --git a/Athos/Networks/SqueezeNetImgNet/AccuracyAnalysisHelper/SqueezeNet_main_float_acc.py b/Athos/Networks/SqueezeNetImgNet/AccuracyAnalysisHelper/SqueezeNet_main_float_acc.py index 48d4e451..ed7f87ab 100644 --- a/Athos/Networks/SqueezeNetImgNet/AccuracyAnalysisHelper/SqueezeNet_main_float_acc.py +++ b/Athos/Networks/SqueezeNetImgNet/AccuracyAnalysisHelper/SqueezeNet_main_float_acc.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy import argparse @@ -28,46 +28,55 @@ import tensorflow as tf import _pickle as pickle -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import squeezenet_main as sqzmain batchsize = 100 N = 200 -data, sqz_mean = sqzmain.load_net('../PreTrainedModel/sqz_full.mat') -image = tf.placeholder(dtype=sqzmain.get_dtype_tf(), shape=(None,227,227,3), name="image_placeholder") +data, sqz_mean = sqzmain.load_net("../PreTrainedModel/sqz_full.mat") +image = tf.placeholder( + dtype=sqzmain.get_dtype_tf(), shape=(None, 227, 227, 3), name="image_placeholder" +) keep_prob = 0.0 -sqznet = sqzmain.net_preloaded(data, image, 'max', True, keep_prob) +sqznet = sqzmain.net_preloaded(data, image, "max", True, keep_prob) -finalActivationsFileName = 'floating_point_acc.outp' -argmaxOutputFileName = 'floating_point_argmax.outp' +finalActivationsFileName = "floating_point_acc.outp" +argmaxOutputFileName = "floating_point_argmax.outp" with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - - with open(finalActivationsFileName,'w') as ff: - pass - with open(argmaxOutputFileName,'w') as ff: - pass - numbatches = N//batchsize - for batchNum in range(numbatches): - startImgNum = (batchNum*batchsize) + 1 - endImgNum = N if (batchNum == numbatches-1) else (((batchNum+1)*batchsize)) - print("Processing images from start,end = {0}, {1}".format(startImgNum, endImgNum)) - images = numpy.zeros(shape=(endImgNum-startImgNum+1,227,227,3)) - for curImgNum in range(startImgNum, endImgNum+1): - with open('./PreProcessedImages/ImageNum_'+str(curImgNum)+'.inp', 'r') as ff: - line = ff.readline() - images[curImgNum-startImgNum] = numpy.reshape(list(map(lambda x : float(x), line.split())), (227,227,3)) - feed_dict = {image: images} - predictions = sess.run(sqznet['classifier_pool'], feed_dict=feed_dict) - with open(finalActivationsFileName, 'a') as ff: - with open(argmaxOutputFileName, 'a') as gg: - for i in range(endImgNum-startImgNum+1): - ff.write('Answer for imgCounter = ' + str(startImgNum+i) + '\n') - for elem in numpy.nditer(predictions[i],order='C'): - ff.write(str(elem)+' ') - ff.write('\n\n') - gg.write('Answer for imgCounter = '+str(startImgNum+i)+' is ') - gg.write(str(numpy.argmax(predictions[i], 2))+'\n') + sess.run(tf.global_variables_initializer()) + with open(finalActivationsFileName, "w") as ff: + pass + with open(argmaxOutputFileName, "w") as ff: + pass + numbatches = N // batchsize + for batchNum in range(numbatches): + startImgNum = (batchNum * batchsize) + 1 + endImgNum = ( + N if (batchNum == numbatches - 1) else (((batchNum + 1) * batchsize)) + ) + print( + "Processing images from start,end = {0}, {1}".format(startImgNum, endImgNum) + ) + images = numpy.zeros(shape=(endImgNum - startImgNum + 1, 227, 227, 3)) + for curImgNum in range(startImgNum, endImgNum + 1): + with open( + "./PreProcessedImages/ImageNum_" + str(curImgNum) + ".inp", "r" + ) as ff: + line = ff.readline() + images[curImgNum - startImgNum] = numpy.reshape( + list(map(lambda x: float(x), line.split())), (227, 227, 3) + ) + feed_dict = {image: images} + predictions = sess.run(sqznet["classifier_pool"], feed_dict=feed_dict) + with open(finalActivationsFileName, "a") as ff: + with open(argmaxOutputFileName, "a") as gg: + for i in range(endImgNum - startImgNum + 1): + ff.write("Answer for imgCounter = " + str(startImgNum + i) + "\n") + for elem in numpy.nditer(predictions[i], order="C"): + ff.write(str(elem) + " ") + ff.write("\n\n") + gg.write("Answer for imgCounter = " + str(startImgNum + i) + " is ") + gg.write(str(numpy.argmax(predictions[i], 2)) + "\n") diff --git a/Athos/Networks/SqueezeNetImgNet/PreProcessingImages/SqNetImgNet_preprocess_main.py b/Athos/Networks/SqueezeNetImgNet/PreProcessingImages/SqNetImgNet_preprocess_main.py index 206b5736..f6750212 100644 --- a/Athos/Networks/SqueezeNetImgNet/PreProcessingImages/SqNetImgNet_preprocess_main.py +++ b/Athos/Networks/SqueezeNetImgNet/PreProcessingImages/SqNetImgNet_preprocess_main.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -24,7 +24,7 @@ Parts of this code were used from https://github.com/avoroshilov/tf-squeezenet. ** -''' +""" import os, sys import numpy as np @@ -34,18 +34,21 @@ import time import numpy + def imread_resize(path): - #img_orig =imread(path) + # img_orig =imread(path) img_orig = Image.open(path).convert("RGB") img_orig = numpy.asarray(img_orig) img = scipy.misc.imresize(img_orig, (227, 227)).astype(np.float) if len(img.shape) == 2: # grayscale - img = np.dstack((img,img,img)) + img = np.dstack((img, img, img)) return img, img_orig.shape + mean_pixel = np.array([104.006, 116.669, 122.679], dtype=np.float32) + def preprocess(image, mean_pixel): swap_img = np.array(image) img_out = np.array(swap_img) @@ -53,44 +56,62 @@ def preprocess(image, mean_pixel): img_out[:, :, 2] = swap_img[:, :, 0] return img_out - mean_pixel + def dumpImageDataFloat(imgData, filename, writeMode): - with open(filename, writeMode) as ff: - for xx in numpy.nditer(imgData, order='C'): - ff.write(str(xx) + ' ') - ff.write('\n\n') + with open(filename, writeMode) as ff: + for xx in numpy.nditer(imgData, order="C"): + ff.write(str(xx) + " ") + ff.write("\n\n") + def main(): - if not((len(sys.argv) >= 7) and (len(sys.argv) <= 8)): - print("Args : ?", file=sys.stderr) - exit(1) - - imgFolderName = sys.argv[1] - bboxFolderName = sys.argv[2] - fileNamePrefix = sys.argv[3] - preProcessedImgFolderName = sys.argv[4] - firstImgNum = int(sys.argv[5]) - lastImgNum = int(sys.argv[6]) - randomSubsetIdxFile = None - if (len(sys.argv) == 8): - randomSubsetIdxFile = sys.argv[7] - - randomIdxToBeChosen = None - if randomSubsetIdxFile: - with open(randomSubsetIdxFile, 'r') as ff: - randomIdxToBeChosen = ff.readlines() - randomIdxToBeChosen = list(map(lambda x : int(x.rstrip()), randomIdxToBeChosen)) - assert(lastImgNum <= len(randomIdxToBeChosen)+1) #Assert that the last img num passed is within bounds - - for curImgNum in range(firstImgNum, lastImgNum): - if (curImgNum % 100 == 0): - print("CurImgNum = ", curImgNum) - actualImgNum = curImgNum if not(randomIdxToBeChosen) else randomIdxToBeChosen[curImgNum-1] - imgFileName = os.path.join(imgFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + '.JPEG') - imgData, imgShape = imread_resize(imgFileName) - preprocessed_image_buffer = preprocess(imgData, mean_pixel) - - saveFilePath = os.path.join(preProcessedImgFolderName, 'ImageNum_' + str(actualImgNum) + '.inp') - dumpImageDataFloat(preprocessed_image_buffer, saveFilePath, 'w') - -if __name__=='__main__': - main() + if not ((len(sys.argv) >= 7) and (len(sys.argv) <= 8)): + print( + "Args : ?", + file=sys.stderr, + ) + exit(1) + + imgFolderName = sys.argv[1] + bboxFolderName = sys.argv[2] + fileNamePrefix = sys.argv[3] + preProcessedImgFolderName = sys.argv[4] + firstImgNum = int(sys.argv[5]) + lastImgNum = int(sys.argv[6]) + randomSubsetIdxFile = None + if len(sys.argv) == 8: + randomSubsetIdxFile = sys.argv[7] + + randomIdxToBeChosen = None + if randomSubsetIdxFile: + with open(randomSubsetIdxFile, "r") as ff: + randomIdxToBeChosen = ff.readlines() + randomIdxToBeChosen = list( + map(lambda x: int(x.rstrip()), randomIdxToBeChosen) + ) + assert ( + lastImgNum <= len(randomIdxToBeChosen) + 1 + ) # Assert that the last img num passed is within bounds + + for curImgNum in range(firstImgNum, lastImgNum): + if curImgNum % 100 == 0: + print("CurImgNum = ", curImgNum) + actualImgNum = ( + curImgNum + if not (randomIdxToBeChosen) + else randomIdxToBeChosen[curImgNum - 1] + ) + imgFileName = os.path.join( + imgFolderName, fileNamePrefix + "{:08d}".format(actualImgNum) + ".JPEG" + ) + imgData, imgShape = imread_resize(imgFileName) + preprocessed_image_buffer = preprocess(imgData, mean_pixel) + + saveFilePath = os.path.join( + preProcessedImgFolderName, "ImageNum_" + str(actualImgNum) + ".inp" + ) + dumpImageDataFloat(preprocessed_image_buffer, saveFilePath, "w") + + +if __name__ == "__main__": + main() diff --git a/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py b/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py index 8b6cbc9b..3c074936 100644 --- a/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py +++ b/Athos/Networks/SqueezeNetImgNet/squeezenet_main.py @@ -12,6 +12,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from tensorflow.python.util import deprecation + deprecation._PRINT_DEPRECATION_WARNINGS = False try: from tensorflow.python.util import module_wrapper as deprecation @@ -19,60 +20,69 @@ from tensorflow.python.util import deprecation_wrapper as deprecation deprecation._PER_MODULE_WARNING_LIMIT = 0 -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'TFCompiler')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "TFCompiler")) import DumpTFMtData + def imread_resize(path): img_orig = Image.open(path).convert("RGB") img_orig = np.asarray(img_orig) - - # NOTE: scipy.misc.imresize is deprecated in > v1.1.0. - # But i cannot find a suitable replacement for this which returns + + # NOTE: scipy.misc.imresize is deprecated in > v1.1.0. + # But i cannot find a suitable replacement for this which returns # exactly the same float value after resizing as scipy.misc.resize. # So, as an alternative, try reinstalling scipy v1.1.0 and then run this code. # Install Scipy v1.1 as : pip3 install scipy==1.1.0 - img = scipy.misc.imresize(img_orig, (227, 227)).astype(np.float) + img = scipy.misc.imresize(img_orig, (227, 227)).astype(np.float) if len(img.shape) == 2: # grayscale - img = np.dstack((img,img,img)) + img = np.dstack((img, img, img)) return img, img_orig.shape + def imsave(path, img): img = np.clip(img, 0, 255).astype(np.uint8) Image.fromarray(img).save(path, quality=95) - + + def get_dtype_np(): return np.float32 + def get_dtype_tf(): return tf.float32 - + + # SqueezeNet v1.1 (signature pool 1/3/5) ######################################## all_weights = [] + def load_net(data_path): if not os.path.isfile(data_path): - parser.error("Network %s does not exist. (Did you forget to download it?)" % data_path) + parser.error( + "Network %s does not exist. (Did you forget to download it?)" % data_path + ) weights_raw = scipy.io.loadmat(data_path) - + # Converting to needed type conv_time = time.time() weights = {} for name in weights_raw: weights[name] = [] # skipping '__version__', '__header__', '__globals__' - if name[0:2] != '__': + if name[0:2] != "__": kernels, bias = weights_raw[name][0] - weights[name].append( kernels.astype(get_dtype_np()) ) - weights[name].append( bias.astype(get_dtype_np()) ) + weights[name].append(kernels.astype(get_dtype_np())) + weights[name].append(bias.astype(get_dtype_np())) print("Converted network data(%s): %fs" % (get_dtype_np(), time.time() - conv_time)) - + mean_pixel = np.array([104.006, 116.669, 122.679], dtype=get_dtype_np()) return weights, mean_pixel + def preprocess(image, mean_pixel): swap_img = np.array(image) img_out = np.array(swap_img) @@ -80,6 +90,7 @@ def preprocess(image, mean_pixel): img_out[:, :, 2] = swap_img[:, :, 0] return img_out - mean_pixel + def unprocess(image, mean_pixel): swap_img = np.array(image + mean_pixel) img_out = np.array(swap_img) @@ -87,37 +98,71 @@ def unprocess(image, mean_pixel): img_out[:, :, 2] = swap_img[:, :, 0] return img_out + def get_weights_biases(preloaded, layer_name): weights, biases = preloaded[layer_name] biases = biases.reshape(-1) return (weights, biases) + def fire_cluster(net, x, preloaded, cluster_name, runPrediction=True): # central - squeeze - layer_name = cluster_name + '/squeeze1x1' + layer_name = cluster_name + "/squeeze1x1" weights, biases = get_weights_biases(preloaded, layer_name) - x = _conv_layer(net, layer_name + '_conv', x, weights, biases, padding='VALID', runPrediction=runPrediction) - x = _act_layer(net, layer_name + '_actv', x) - + x = _conv_layer( + net, + layer_name + "_conv", + x, + weights, + biases, + padding="VALID", + runPrediction=runPrediction, + ) + x = _act_layer(net, layer_name + "_actv", x) + # left - expand 1x1 - layer_name = cluster_name + '/expand1x1' + layer_name = cluster_name + "/expand1x1" weights, biases = get_weights_biases(preloaded, layer_name) - x_l = _conv_layer(net, layer_name + '_conv', x, weights, biases, padding='VALID', runPrediction=runPrediction) - x_l = _act_layer(net, layer_name + '_actv', x_l) + x_l = _conv_layer( + net, + layer_name + "_conv", + x, + weights, + biases, + padding="VALID", + runPrediction=runPrediction, + ) + x_l = _act_layer(net, layer_name + "_actv", x_l) # right - expand 3x3 - layer_name = cluster_name + '/expand3x3' + layer_name = cluster_name + "/expand3x3" weights, biases = get_weights_biases(preloaded, layer_name) - x_r = _conv_layer(net, layer_name + '_conv', x, weights, biases, padding='SAME', runPrediction=runPrediction) - x_r = _act_layer(net, layer_name + '_actv', x_r) - + x_r = _conv_layer( + net, + layer_name + "_conv", + x, + weights, + biases, + padding="SAME", + runPrediction=runPrediction, + ) + x_r = _act_layer(net, layer_name + "_actv", x_r) + # concatenate expand 1x1 (left) and expand 3x3 (right) x = tf.concat([x_l, x_r], 3) - net[cluster_name + '/concat_conc'] = x - + net[cluster_name + "/concat_conc"] = x + return x -def net_preloaded(preloaded, input_image, pooling, needs_classifier=False, keep_prob=None, runPrediction=True): + +def net_preloaded( + preloaded, + input_image, + pooling, + needs_classifier=False, + keep_prob=None, + runPrediction=True, +): net = {} cr_time = time.time() @@ -125,98 +170,195 @@ def net_preloaded(preloaded, input_image, pooling, needs_classifier=False, keep_ # Feature extractor ##################### - + # conv1 cluster - layer_name = 'conv1' + layer_name = "conv1" weights, biases = get_weights_biases(preloaded, layer_name) - x = _conv_layer(net, layer_name + '_conv', x, weights, biases, padding='VALID', stride=(2, 2), runPrediction=runPrediction) - x = _act_layer(net, layer_name + '_actv', x) - x = _pool_layer(net, 'pool1_pool', x, pooling, size=(3, 3), stride=(2, 2), padding='VALID') + x = _conv_layer( + net, + layer_name + "_conv", + x, + weights, + biases, + padding="VALID", + stride=(2, 2), + runPrediction=runPrediction, + ) + x = _act_layer(net, layer_name + "_actv", x) + x = _pool_layer( + net, "pool1_pool", x, pooling, size=(3, 3), stride=(2, 2), padding="VALID" + ) # fire2 + fire3 clusters - x = fire_cluster(net, x, preloaded, cluster_name='fire2', runPrediction=runPrediction) + x = fire_cluster( + net, x, preloaded, cluster_name="fire2", runPrediction=runPrediction + ) fire2_bypass = x - x = fire_cluster(net, x, preloaded, cluster_name='fire3', runPrediction=runPrediction) - x = _pool_layer(net, 'pool3_pool', x, pooling, size=(3, 3), stride=(2, 2), padding='VALID') + x = fire_cluster( + net, x, preloaded, cluster_name="fire3", runPrediction=runPrediction + ) + x = _pool_layer( + net, "pool3_pool", x, pooling, size=(3, 3), stride=(2, 2), padding="VALID" + ) # fire4 + fire5 clusters - x = fire_cluster(net, x, preloaded, cluster_name='fire4', runPrediction=runPrediction) + x = fire_cluster( + net, x, preloaded, cluster_name="fire4", runPrediction=runPrediction + ) fire4_bypass = x - x = fire_cluster(net, x, preloaded, cluster_name='fire5', runPrediction=runPrediction) - x = _pool_layer(net, 'pool5_pool', x, pooling, size=(3, 3), stride=(2, 2), padding='VALID') + x = fire_cluster( + net, x, preloaded, cluster_name="fire5", runPrediction=runPrediction + ) + x = _pool_layer( + net, "pool5_pool", x, pooling, size=(3, 3), stride=(2, 2), padding="VALID" + ) # remainder (no pooling) - x = fire_cluster(net, x, preloaded, cluster_name='fire6', runPrediction=runPrediction) + x = fire_cluster( + net, x, preloaded, cluster_name="fire6", runPrediction=runPrediction + ) fire6_bypass = x - x = fire_cluster(net, x, preloaded, cluster_name='fire7', runPrediction=runPrediction) - x = fire_cluster(net, x, preloaded, cluster_name='fire8', runPrediction=runPrediction) - x = fire_cluster(net, x, preloaded, cluster_name='fire9', runPrediction=runPrediction) - + x = fire_cluster( + net, x, preloaded, cluster_name="fire7", runPrediction=runPrediction + ) + x = fire_cluster( + net, x, preloaded, cluster_name="fire8", runPrediction=runPrediction + ) + x = fire_cluster( + net, x, preloaded, cluster_name="fire9", runPrediction=runPrediction + ) + # Classifier ##################### if needs_classifier == True: # Dropout [use value of 50% when training] # x = tf.nn.dropout(x, keep_prob) - + # Fixed global avg pool/softmax classifier: # [227, 227, 3] -> 1000 classes - layer_name = 'conv10' + layer_name = "conv10" weights, biases = get_weights_biases(preloaded, layer_name) - x = _conv_layer(net, layer_name + '_conv', x, weights, biases, runPrediction=runPrediction) - x = _act_layer(net, layer_name + '_actv', x) - + x = _conv_layer( + net, layer_name + "_conv", x, weights, biases, runPrediction=runPrediction + ) + x = _act_layer(net, layer_name + "_actv", x) + # Global Average Pooling - x = tf.nn.avg_pool(x, ksize=(1, 13, 13, 1), strides=(1, 1, 1, 1), padding='VALID') - net['classifier_pool'] = x - + x = tf.nn.avg_pool( + x, ksize=(1, 13, 13, 1), strides=(1, 1, 1, 1), padding="VALID" + ) + net["classifier_pool"] = x + # x = tf.nn.softmax(x) # net['classifier_actv'] = x - + print("Network instance created: %fs" % (time.time() - cr_time)) - + return net - -def _conv_layer(net, name, input, weights, bias, padding='SAME', stride=(1, 1), runPrediction=True): + + +def _conv_layer( + net, name, input, weights, bias, padding="SAME", stride=(1, 1), runPrediction=True +): global all_weights if runPrediction: - conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, stride[0], stride[1], 1), - padding=padding) + conv = tf.nn.conv2d( + input, + tf.constant(weights), + strides=(1, stride[0], stride[1], 1), + padding=padding, + ) x = tf.nn.bias_add(conv, bias) else: - conv = tf.nn.conv2d(input, tf.Variable(tf.constant(0.1,shape=weights.shape)), strides=(1, stride[0], stride[1], 1), - padding=padding) - x = tf.nn.bias_add(conv, tf.Variable(tf.constant(0.1,shape=bias.shape))) + conv = tf.nn.conv2d( + input, + tf.Variable(tf.constant(0.1, shape=weights.shape)), + strides=(1, stride[0], stride[1], 1), + padding=padding, + ) + x = tf.nn.bias_add(conv, tf.Variable(tf.constant(0.1, shape=bias.shape))) net[name] = x all_weights.append(weights) all_weights.append(bias) return x + def _act_layer(net, name, input): x = tf.nn.relu(input) net[name] = x return x - -def _pool_layer(net, name, input, pooling, size=(2, 2), stride=(3, 3), padding='SAME'): - if pooling == 'avg': - x = tf.nn.avg_pool(input, ksize=(1, size[0], size[1], 1), strides=(1, stride[0], stride[1], 1), - padding=padding) + + +def _pool_layer(net, name, input, pooling, size=(2, 2), stride=(3, 3), padding="SAME"): + if pooling == "avg": + x = tf.nn.avg_pool( + input, + ksize=(1, size[0], size[1], 1), + strides=(1, stride[0], stride[1], 1), + padding=padding, + ) else: - x = tf.nn.max_pool(input, ksize=(1, size[0], size[1], 1), strides=(1, stride[0], stride[1], 1), - padding=padding) + x = tf.nn.max_pool( + input, + ksize=(1, size[0], size[1], 1), + strides=(1, stride[0], stride[1], 1), + padding=padding, + ) net[name] = x return x + def build_parser(): ps = ArgumentParser() - ps.add_argument('--in', dest='input', help='input file', metavar='INPUT', required=True) - ps.add_argument('--saveTFMetadata', dest='saveTFMetadata', type=bool, help='bool to indicate if to save metadata', required=False) - ps.add_argument('--saveImgAndWtData', dest='saveImgAndWtData', type=bool, help='bool to indicate if to save img and model weights', required=False) - ps.add_argument('--savePreTrainedWeightsFloat', dest='savePreTrainedWeightsFloat', type=bool, help='bool to indicate if to save model weights float', required=False) - ps.add_argument('--savePreTrainedWeightsInt', dest='savePreTrainedWeightsInt', type=bool, help='bool to indicate if to save model weights int', required=False) - ps.add_argument('--saveImgAndWeightsSeparately', dest='saveImgAndWeightsSeparately', type=bool, help='bool to indicate if to save image and model weights int separately', required=False) - ps.add_argument('--scalingFac', dest='scalingFac', type=int, help='scalingFac', default=15, required=False) + ps.add_argument( + "--in", dest="input", help="input file", metavar="INPUT", required=True + ) + ps.add_argument( + "--saveTFMetadata", + dest="saveTFMetadata", + type=bool, + help="bool to indicate if to save metadata", + required=False, + ) + ps.add_argument( + "--saveImgAndWtData", + dest="saveImgAndWtData", + type=bool, + help="bool to indicate if to save img and model weights", + required=False, + ) + ps.add_argument( + "--savePreTrainedWeightsFloat", + dest="savePreTrainedWeightsFloat", + type=bool, + help="bool to indicate if to save model weights float", + required=False, + ) + ps.add_argument( + "--savePreTrainedWeightsInt", + dest="savePreTrainedWeightsInt", + type=bool, + help="bool to indicate if to save model weights int", + required=False, + ) + ps.add_argument( + "--saveImgAndWeightsSeparately", + dest="saveImgAndWeightsSeparately", + type=bool, + help="bool to indicate if to save image and model weights int separately", + required=False, + ) + ps.add_argument( + "--scalingFac", + dest="scalingFac", + type=int, + help="scalingFac", + default=15, + required=False, + ) return ps + def main(): import time @@ -229,29 +371,33 @@ def main(): # Loading ImageNet classes info classes = [] - with open('synset_words.txt', 'r') as classes_file: + with open("synset_words.txt", "r") as classes_file: classes = classes_file.read().splitlines() # Loading network - data, sqz_mean = load_net('./PreTrainedModel/sqz_full.mat') + data, sqz_mean = load_net("./PreTrainedModel/sqz_full.mat") - config = tf.ConfigProto(log_device_placement = False) + config = tf.ConfigProto(log_device_placement=False) config.gpu_options.allow_growth = True - config.gpu_options.allocator_type = 'BFC' + config.gpu_options.allocator_type = "BFC" g = tf.Graph() - + # 1st pass - simple classification with g.as_default(), tf.Session(config=config) as sess: # Building network - image = tf.placeholder(dtype=get_dtype_tf(), shape=img_content_shape, name="image_placeholder") + image = tf.placeholder( + dtype=get_dtype_tf(), shape=img_content_shape, name="image_placeholder" + ) # keep_prob = tf.placeholder(get_dtype_tf()) keep_prob = 0.0 saveTFMetadata = False if options.saveTFMetadata: saveTFMetadata = options.saveTFMetadata - sqznet = net_preloaded(data, image, 'max', True, keep_prob, runPrediction=not(saveTFMetadata)) - final_class = tf.argmax(sqznet['classifier_pool'],3) + sqznet = net_preloaded( + data, image, "max", True, keep_prob, runPrediction=not (saveTFMetadata) + ) + final_class = tf.argmax(sqznet["classifier_pool"], 3) sess.run(tf.global_variables_initializer()) imageData = [preprocess(img_content, sqz_mean)] @@ -261,11 +407,13 @@ def main(): output_tensor = None gg = tf.get_default_graph() for node in gg.as_graph_def().node: - if node.name == 'ArgMax': # Final activation, not the argmax + if node.name == "ArgMax": # Final activation, not the argmax output_tensor = gg.get_operation_by_name(node.name).outputs[0] - assert(output_tensor is not None) + assert output_tensor is not None if options.saveTFMetadata: - optimized_graph_def = DumpTFMtData.save_graph_metadata(output_tensor, sess, feed_dict) + optimized_graph_def = DumpTFMtData.save_graph_metadata( + output_tensor, sess, feed_dict + ) else: # Classifying print("*************** Starting Prediction****************") @@ -275,20 +423,42 @@ def main(): print("*************** Done Prediction****************") duration = end_time - start_time print("Time taken in inference : ", duration) - with open('tf_pred.float','w+') as f: + with open("tf_pred.float", "w+") as f: f.write(DumpTFMtData.numpy_float_array_to_float_val_str(sqz_class)) - with open('tf_pred.time','w') as f: - f.write(str(round(duration, 2))) + with open("tf_pred.time", "w") as f: + f.write(str(round(duration, 2))) # Outputting result print("\nclass: [%d] '%s'" % (sqz_class, classes[sqz_class])) if options.savePreTrainedWeightsInt: - DumpTFMtData.dumpTrainedWeightsInt(sess, all_weights, "model_weights_scale_{}.inp".format(options.scalingFac), options.scalingFac, 'w', alreadyEvaluated=True) + DumpTFMtData.dumpTrainedWeightsInt( + sess, + all_weights, + "model_weights_scale_{}.inp".format(options.scalingFac), + options.scalingFac, + "w", + alreadyEvaluated=True, + ) if options.savePreTrainedWeightsFloat: - DumpTFMtData.dumpTrainedWeightsFloat(sess, all_weights, 'model_weights_float.inp', 'w', alreadyEvaluated=True) + DumpTFMtData.dumpTrainedWeightsFloat( + sess, + all_weights, + "model_weights_float.inp", + "w", + alreadyEvaluated=True, + ) if options.saveImgAndWtData: - DumpTFMtData.dumpImgAndWeightsDataSeparate(sess, imageData, all_weights, "model_input_scale_{}.inp".format(options.scalingFac), "model_weights_scale_{}.inp".format(options.scalingFac), options.scalingFac, alreadyEvaluated=True) - -if __name__ == '__main__': - main() \ No newline at end of file + DumpTFMtData.dumpImgAndWeightsDataSeparate( + sess, + imageData, + all_weights, + "model_input_scale_{}.inp".format(options.scalingFac), + "model_weights_scale_{}.inp".format(options.scalingFac), + options.scalingFac, + alreadyEvaluated=True, + ) + + +if __name__ == "__main__": + main() diff --git a/Athos/ONNXCompiler/ONNXNodesAST.py b/Athos/ONNXCompiler/ONNXNodesAST.py index a1ec36f7..1cbc28ee 100644 --- a/Athos/ONNXCompiler/ONNXNodesAST.py +++ b/Athos/ONNXCompiler/ONNXNodesAST.py @@ -1,4 +1,4 @@ -''' +""" Authors: Shubham Ugare. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import AST.AST as AST from onnx import mapping @@ -28,24 +28,29 @@ from numbers import Number DEBUG = False -out_var_prefix = 'J' +out_var_prefix = "J" + class OnnxNode(object): - """ - Reimplementation of NodeProto from ONNX, but in a form - more convenient to work with from Python. - """ - - def __init__(self, node): - self.name = str(node.name) - self.op_type = str(node.op_type) - self.domain = str(node.domain) - self.attrs = dict([(attr.name, - translate_onnx(attr.name, convert_onnx(attr))) - for attr in node.attribute]) - self.inputs = list(node.input) - self.outputs = list(node.output) - self.node_proto = node + """ + Reimplementation of NodeProto from ONNX, but in a form + more convenient to work with from Python. + """ + + def __init__(self, node): + self.name = str(node.name) + self.op_type = str(node.op_type) + self.domain = str(node.domain) + self.attrs = dict( + [ + (attr.name, translate_onnx(attr.name, convert_onnx(attr))) + for attr in node.attribute + ] + ) + self.inputs = list(node.input) + self.outputs = list(node.output) + self.node_proto = node + __onnx_attr_translator = { "axis": lambda x: int(x), @@ -57,842 +62,1360 @@ def __init__(self, node): def convert_onnx(attr): - return __convert_onnx_attribute_proto(attr) + return __convert_onnx_attribute_proto(attr) def __convert_onnx_attribute_proto(attr_proto): - """ - Convert an ONNX AttributeProto into an appropriate Python object - for the type. - NB: Tensor attribute gets returned as the straight proto. - """ - if attr_proto.HasField('f'): - return attr_proto.f - elif attr_proto.HasField('i'): - return attr_proto.i - elif attr_proto.HasField('s'): - return str(attr_proto.s, 'utf-8') - elif attr_proto.HasField('t'): - return attr_proto.t # this is a proto! - elif attr_proto.HasField('g'): - return attr_proto.g - elif attr_proto.floats: - return list(attr_proto.floats) - elif attr_proto.ints: - return list(attr_proto.ints) - elif attr_proto.strings: - str_list = list(attr_proto.strings) - if IS_PYTHON3: - str_list = list(map(lambda x: str(x, 'utf-8'), str_list)) - return str_list - elif attr_proto.HasField('sparse_tensor'): - return attr_proto.sparse_tensor - else: - raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto)) + """ + Convert an ONNX AttributeProto into an appropriate Python object + for the type. + NB: Tensor attribute gets returned as the straight proto. + """ + if attr_proto.HasField("f"): + return attr_proto.f + elif attr_proto.HasField("i"): + return attr_proto.i + elif attr_proto.HasField("s"): + return str(attr_proto.s, "utf-8") + elif attr_proto.HasField("t"): + return attr_proto.t # this is a proto! + elif attr_proto.HasField("g"): + return attr_proto.g + elif attr_proto.floats: + return list(attr_proto.floats) + elif attr_proto.ints: + return list(attr_proto.ints) + elif attr_proto.strings: + str_list = list(attr_proto.strings) + if IS_PYTHON3: + str_list = list(map(lambda x: str(x, "utf-8"), str_list)) + return str_list + elif attr_proto.HasField("sparse_tensor"): + return attr_proto.sparse_tensor + else: + raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto)) + def translate_onnx(key, val): - return __onnx_attr_translator.get(key, lambda x: x)(val) + return __onnx_attr_translator.get(key, lambda x: x)(val) + def onnx2seedot(dtype): - return TENSOR_TYPE_TO_SEEDOT_TYPE[_onnx_dtype(dtype)] + return TENSOR_TYPE_TO_SEEDOT_TYPE[_onnx_dtype(dtype)] + def _onnx_dtype(dtype): - if isinstance(dtype, Number): - onnx_dype = dtype - elif isinstance(dtype, str): - onnx_dype = TensorProto.DataType.Value(dtype) - else: - raise RuntimeError("dtype should be number or str.") - return onnx_dype + if isinstance(dtype, Number): + onnx_dype = dtype + elif isinstance(dtype, str): + onnx_dype = TensorProto.DataType.Value(dtype) + else: + raise RuntimeError("dtype should be number or str.") + return onnx_dype + TENSOR_TYPE_TO_SEEDOT_TYPE = { - int(TensorProto.FLOAT): 'float32', - int(TensorProto.UINT8): 'uint8', - int(TensorProto.INT8): 'int8', - int(TensorProto.UINT16): 'uint16', - int(TensorProto.INT16): 'int16', - int(TensorProto.INT32): 'int32', - int(TensorProto.INT64): 'int64', - int(TensorProto.BOOL): 'bool', - int(TensorProto.FLOAT16): 'float16', - int(TensorProto.DOUBLE): 'float64', - int(TensorProto.COMPLEX64): 'complex64', - int(TensorProto.COMPLEX128): 'complex128', - int(TensorProto.UINT32): 'uint32', - int(TensorProto.UINT64): 'uint64', - int(TensorProto.STRING): 'string' + int(TensorProto.FLOAT): "float32", + int(TensorProto.UINT8): "uint8", + int(TensorProto.INT8): "int8", + int(TensorProto.UINT16): "uint16", + int(TensorProto.INT16): "int16", + int(TensorProto.INT32): "int32", + int(TensorProto.INT64): "int64", + int(TensorProto.BOOL): "bool", + int(TensorProto.FLOAT16): "float16", + int(TensorProto.DOUBLE): "float64", + int(TensorProto.COMPLEX64): "complex64", + int(TensorProto.COMPLEX128): "complex128", + int(TensorProto.UINT32): "uint32", + int(TensorProto.UINT64): "uint64", + int(TensorProto.STRING): "string", } + def getOperatorsIdx(token): - #TODO : remove usage of this - return AST.Operators.convSymbolToEnumValue(token) + # TODO : remove usage of this + return AST.Operators.convSymbolToEnumValue(token) + def get_seedot_shape_order(old_shape): - if(len(old_shape) == 4): - # Case when spatial dimension is 2 - # inverse of [1, 3, 4, 2] is [1, 4, 2, 3] - return ([old_shape[0], old_shape[2], old_shape[3], old_shape[1]], [1, 4, 2, 3]) - else: - # Casr when spatial dimension is 3 - # inverse of [1, 3, 4, 5, 2] is [1, 5, 2, 3, 4] - return ([old_shape[0], old_shape[2], old_shape[3], old_shape[4], old_shape[1]], [1, 5, 2, 3, 4]) + if len(old_shape) == 4: + # Case when spatial dimension is 2 + # inverse of [1, 3, 4, 2] is [1, 4, 2, 3] + return ([old_shape[0], old_shape[2], old_shape[3], old_shape[1]], [1, 4, 2, 3]) + else: + # Casr when spatial dimension is 3 + # inverse of [1, 3, 4, 5, 2] is [1, 5, 2, 3, 4] + return ( + [old_shape[0], old_shape[2], old_shape[3], old_shape[4], old_shape[1]], + [1, 5, 2, 3, 4], + ) + def get_seedot_filter_shape_order(filter_shape): - if(len(filter_shape) == 4): - # Case when spatial dimension is 2 - # inverse of [3, 4, 2, 1] is [4, 3, 1, 2] - return ([filter_shape[2], filter_shape[3], filter_shape[1], filter_shape[0]], [4, 3, 1, 2]) - else: - # Casr when spatial dimension is 3 - # inverse of [3, 4, 5, 2, 1] is [5, 4, 1, 2, 3] - return ([filter_shape[2], filter_shape[3], filter_shape[4], filter_shape[1], filter_shape[0]], [5, 4, 1, 2, 3]) + if len(filter_shape) == 4: + # Case when spatial dimension is 2 + # inverse of [3, 4, 2, 1] is [4, 3, 1, 2] + return ( + [filter_shape[2], filter_shape[3], filter_shape[1], filter_shape[0]], + [4, 3, 1, 2], + ) + else: + # Casr when spatial dimension is 3 + # inverse of [3, 4, 5, 2, 1] is [5, 4, 1, 2, 3] + return ( + [ + filter_shape[2], + filter_shape[3], + filter_shape[4], + filter_shape[1], + filter_shape[0], + ], + [5, 4, 1, 2, 3], + ) + def get_onnx_order(onnx_shape): - if(len(onnx_shape) == 4): - # inverse of [1, 4, 2, 3] is [1, 3, 4, 2] - return [1, 3, 4, 2] - else: - # inverse of [1, 5, 2, 3, 4] is [1, 3, 4, 5, 2] - return [1, 3, 4, 5, 2] + if len(onnx_shape) == 4: + # inverse of [1, 4, 2, 3] is [1, 3, 4, 2] + return [1, 3, 4, 2] + else: + # inverse of [1, 5, 2, 3, 4] is [1, 3, 4, 5, 2] + return [1, 3, 4, 5, 2] + def get_reshaped_input_ast(input_name, value_info, node_name_to_out_var_dict): - onnx_input_shape = list(value_info[input_name][1]) - (seedot_input_shape, seedot_input_order) = get_seedot_shape_order(onnx_input_shape) - return AST.Reshape(AST.ID(node_name_to_out_var_dict[input_name]), seedot_input_shape, seedot_input_order) + onnx_input_shape = list(value_info[input_name][1]) + (seedot_input_shape, seedot_input_order) = get_seedot_shape_order(onnx_input_shape) + return AST.Reshape( + AST.ID(node_name_to_out_var_dict[input_name]), + seedot_input_shape, + seedot_input_order, + ) + def get_reshaped_bias_ast(bias_name, value_info, node_name_to_out_var_dict, dim): - if(dim == 2): - return AST.Reshape(AST.ID(node_name_to_out_var_dict[bias_name]), [1, 1, 1, value_info[bias_name][1][0]], None) - else: - return AST.Reshape(AST.ID(node_name_to_out_var_dict[bias_name]), [1, 1, 1, 1, value_info[bias_name][1][0]], None) + if dim == 2: + return AST.Reshape( + AST.ID(node_name_to_out_var_dict[bias_name]), + [1, 1, 1, value_info[bias_name][1][0]], + None, + ) + else: + return AST.Reshape( + AST.ID(node_name_to_out_var_dict[bias_name]), + [1, 1, 1, 1, value_info[bias_name][1][0]], + None, + ) + def get_reshaped_filter_ast(filter_name, value_info, node_name_to_out_var_dict): - onnx_filter_shape = list(value_info[filter_name][1]) - (seedot_filter_shape, seedot_filter_order) = get_seedot_filter_shape_order(onnx_filter_shape) - return AST.Reshape(AST.ID(node_name_to_out_var_dict[filter_name]), seedot_filter_shape, seedot_filter_order) + onnx_filter_shape = list(value_info[filter_name][1]) + (seedot_filter_shape, seedot_filter_order) = get_seedot_filter_shape_order( + onnx_filter_shape + ) + return AST.Reshape( + AST.ID(node_name_to_out_var_dict[filter_name]), + seedot_filter_shape, + seedot_filter_order, + ) + + +def get_reshaped_output_ast(onnx_output_name, value_info, output_name): + onnx_output_shape = list(value_info[onnx_output_name][1]) + onnx_output_order = get_onnx_order(onnx_output_shape) + return AST.Reshape(AST.ID(output_name), onnx_output_shape, onnx_output_order) -def get_reshaped_output_ast(onnx_output_name, value_info, output_name): - onnx_output_shape = list(value_info[onnx_output_name][1]) - onnx_output_order = get_onnx_order(onnx_output_shape) - return AST.Reshape(AST.ID(output_name), onnx_output_shape, onnx_output_order) def get_new_var_name(out_var_count): - return out_var_prefix + str(out_var_count) - -def update_program_with_new_node(innermost_let_ast_node, new_node, new_node_name, mtdAST): - cur_out_var_ast_node = AST.ID(new_node_name) - new_let_node = AST.Let(cur_out_var_ast_node, new_node, cur_out_var_ast_node) - mtdAST.visit(new_let_node, {AST.ASTNode.mtdKeyTFOpName : 'no', AST.ASTNode.mtdKeyTFNodeName : 'no'}) - # Updating the innermost Let AST node and the expression for previous Let Node - innermost_let_ast_node.expr = new_let_node - innermost_let_ast_node = new_let_node - - # node_name_to_out_var_dict[node.outputs[0]] = new_node_name - return innermost_let_ast_node + return out_var_prefix + str(out_var_count) + + +def update_program_with_new_node( + innermost_let_ast_node, new_node, new_node_name, mtdAST +): + cur_out_var_ast_node = AST.ID(new_node_name) + new_let_node = AST.Let(cur_out_var_ast_node, new_node, cur_out_var_ast_node) + mtdAST.visit( + new_let_node, + {AST.ASTNode.mtdKeyTFOpName: "no", AST.ASTNode.mtdKeyTFNodeName: "no"}, + ) + # Updating the innermost Let AST node and the expression for previous Let Node + innermost_let_ast_node.expr = new_let_node + innermost_let_ast_node = new_let_node + + # node_name_to_out_var_dict[node.outputs[0]] = new_node_name + return innermost_let_ast_node + class ONNXNodesAST: - # value_info: dictionary of name -> (type, dimension tuple) - def Input(node, value_info, node_name_to_out_var_dict, party=0): - if(DEBUG): - print(node.outputs[0]) - # There are two types of inputs - dims = list(node.dims if hasattr(node, 'dims') else ([val.dim_value for val in node.type.tensor_type.shape.dim])) - data_type = node.data_type if hasattr (node, 'data_type') else node.type.tensor_type.elem_type - return AST.Input(dims, onnx2seedot(data_type), inputByParty=party) - - - def Cast(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - inputsRef = node.inputs - assert(len(inputsRef) == 1) - # destType = node.attrs['to'] - - # seedot_output_ast = AST.UninterpFuncCall(value_info[node.outputs[0]][1], - # 'Cast', - # [AST.ID(inputsRef[0]), - # AST.ID(destType), - # AST.ID(destType) - # ]) - # output_name = get_new_var_name(out_var_count) - # innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - # out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = inputsRef[0] - - return (innermost_let_ast_node, out_var_count) - - def Pad(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - inputsRef = node.inputs - # Skip constant_val input (last input) - inpLen = len(inputsRef) - 1 - assert(inpLen == 2) - inputs = [AST.ID(node_name_to_out_var_dict[inputsRef[x]]) for x in range(0, inpLen)] - mode = node.attrs['mode'] - assert(mode == 'constant') - seedot_output_ast = AST.UninterpFuncCall(list(value_info[node.outputs[0]][1]), - 'PadONNX', inputs) - - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - node_name_to_out_var_dict[node.outputs[0]] = output_name - - return (innermost_let_ast_node, out_var_count) - - def Concat(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - inputsRef = node.inputs - N = len(inputsRef) - - inputs = [AST.ID(node_name_to_out_var_dict[inputsRef[x]]) for x in range(0, len(inputsRef))] - axis = node.attrs['axis'] - - seedot_output_ast = AST.UninterpFuncCall(list(value_info[node.outputs[0]][1]), - 'Concat'+str(N) + 'T', - inputs + [AST.Int(axis, 32, False)], - outputDiffInpDims=1 - ) - - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - node_name_to_out_var_dict[node.outputs[0]] = output_name - - return (innermost_let_ast_node, out_var_count) - - def Relu(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - - inputsRef = node.inputs - assert(len(inputsRef)==1) - - - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.Func(getOperatorsIdx('relu'), AST.ID(reshaped_input_name)) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - reshaped_output_name = get_new_var_name(out_var_count) - onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name - - if(DEBUG): - print(node.outputs[0]) - print(onnx_input_shape, '->', seedot_input_shape, '->', onnx_output_shape) - - return (innermost_let_ast_node, out_var_count) - # return AST.Func(getOperatorsIdx('relu'), AST.ID(node_name_to_out_var_dict[inputsRef[0]])) - - def Add(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - inputsRef = node.inputs - assert(len(inputsRef) == 2) - - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - - reshaped_input_name1 = get_new_var_name(out_var_count) - reshaped_input1 = get_reshaped_input_ast(inputsRef[1], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input1, reshaped_input_name1, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), - getOperatorsIdx('+'), - AST.ID(reshaped_input_name1) - ) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - - reshaped_output_name = get_new_var_name(out_var_count) - onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name - - if(DEBUG): - print(node.outputs[0]) - print(onnx_input_shape, onnx_input_shape1, '->', seedot_input_shape, seedot_input_shape1, '->', onnx_output_shape) - - return (innermost_let_ast_node, out_var_count) - - - def Gemm(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - inputsRef = node.inputs - assert(len(inputsRef) == 3) - input1AST = AST.ID(node_name_to_out_var_dict[inputsRef[0]]) - input2AST = AST.ID(node_name_to_out_var_dict[inputsRef[1]]) - - if('transA' in node.attrs and node.attrs['transA']): input1AST = AST.Transpose(input1AST) - if('transB' in node.attrs and node.attrs['transB']): input2AST = AST.Transpose(input2AST) - - # W*x + b - seedot_output_ast = AST.BOp(AST.BOp(input1AST, getOperatorsIdx('*'), input2AST), getOperatorsIdx('+'), AST.ID(node_name_to_out_var_dict[inputsRef[2]])) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - node_name_to_out_var_dict[node.outputs[0]] = output_name - - return (innermost_let_ast_node, out_var_count) - - def Constant(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - # TODO: Use AST.decl for defining a tensor. If used as a parameter for Reshape then we don't need it for now. - return (innermost_let_ast_node, out_var_count) - - def Transpose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - - inputsRef = node.inputs - assert(len(inputsRef)==1) - - seedot_output_ast = AST.Transpose(AST.ID(node_name_to_out_var_dict[inputsRef[0]]), node.attrs['perm']) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = output_name - - return (innermost_let_ast_node, out_var_count) - - # Only supports split into equal parts - def Split(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - inputsRef = node.inputs - output_count = len(node.outputs) - - for cur_count in range(output_count): - seedot_output_ast = AST.UninterpFuncCall(list(value_info[node.outputs[cur_count]][1]), 'Split', - [AST.ID(node_name_to_out_var_dict[inputsRef[0]]), AST.Int(node.attrs['axis'], 32, False), AST.Int(cur_count, 32, False), AST.Int(output_count, 32, False)]) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[cur_count]] = output_name - - return (innermost_let_ast_node, out_var_count) - - def ReduceMean(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - inputsRef = node.inputs - - keepdims = node.attrs['keepdims'] - axes = node.attrs['axes'] - - # currently handling only this case - # currently support only 0 case - assert(keepdims == 0) - assert(len(axes) == 2) - - seedot_output_ast = AST.UninterpFuncCall(value_info[node.outputs[0]][1], 'ReduceMeanONNX', - [AST.ID(node_name_to_out_var_dict[inputsRef[0]]), AST.Int(axes[0], 32, False), AST.Int(axes[1], 32, False)]) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = output_name - return (innermost_let_ast_node, out_var_count) - - def BatchNormalization(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - - inputsRef = node.inputs - # Are running mean and var used for something? - assert(len(inputsRef)==5) - - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.FusedBatchNorm(AST.ID(reshaped_input_name), - AST.ID(node_name_to_out_var_dict[inputsRef[1]]), - AST.ID(node_name_to_out_var_dict[inputsRef[2]]), - ) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - reshaped_output_name = get_new_var_name(out_var_count) - onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name - - if(DEBUG): - print(node.outputs[0]) - print(onnx_input_shape, '->', seedot_input_shape, '->', onnx_output_shape) - - return (innermost_let_ast_node, out_var_count) - - def Reshape(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - - inputsRef = node.inputs - assert(len(inputsRef)==2) - # print(list(value_info[node.outputs[0]][1])) - - seedot_output_ast = AST.Reshape(AST.ID(node_name_to_out_var_dict[inputsRef[0]]), list(value_info[node.outputs[0]][1]), None) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = output_name - - return (innermost_let_ast_node, out_var_count) - - def Flatten(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - - inputsRef = node.inputs - assert(len(inputsRef)==1) - - seedot_output_ast = AST.Reshape(AST.ID(node_name_to_out_var_dict[inputsRef[0]]), list(value_info[node.outputs[0]][1]), None) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = output_name - - return (innermost_let_ast_node, out_var_count) - - def Conv(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - - inputsRef = node.inputs - # since two dimensions represent N: Number of batches and CI: Input channel - inputShape = value_info[inputsRef[0]][1] - spatial_size = len(inputShape)-2 - - if spatial_size == 2: - (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv2d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) - elif spatial_size == 3: - (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv3d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) - - reshaped_output_name = get_new_var_name(out_var_count) - onnx_output_ast = get_reshaped_output_ast(node.outputs[0],value_info, output_name) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name - - return (innermost_let_ast_node, out_var_count) - - def conv2d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - inputsRef = node.inputs - inputShape = value_info[inputsRef[0]][1] - filterShape = value_info[inputsRef[1]][1] - - stridesUsed = node.attrs['strides'] - - assert(len(inputsRef)==2 or len(inputsRef)==3) - assert(len(stridesUsed)==2) - assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) - - group = node.attrs['group'] if 'group' in node.attrs else 1 - [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs['pads'] if 'pads' in node.attrs else [0,0,0,0] - # we assume VALID case when the padding is in string format - - options = {} - options[AST.PaddingKeysDict.FH] = filterShape[2] - options[AST.PaddingKeysDict.FW] = filterShape[3] - options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft - options[AST.PaddingKeysDict.zPadHRight] = zPadHRight - options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft - options[AST.PaddingKeysDict.zPadWRight] = zPadWRight - options[AST.PaddingKeysDict.strideH] = stridesUsed[0] - options[AST.PaddingKeysDict.strideW] = stridesUsed[1] - options[AST.PaddingKeysDict.ConvDim] = 2 - options[AST.PaddingKeysDict.group] = group - - # print(inputShape, filterShape) - assert (inputShape[1] == filterShape[1]*group) - # For Input: - # [N, CI, H, W] is the Onnx order it should be changed to - # [N, H, W, CI] order - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - - # For filter: - # [CO, CI1, FH, FW] is the Onnx order it should be changed to - # [FH, FW, CI1, CO] order - reshaped_filter_name = get_new_var_name(out_var_count) - reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#'), AST.ID(reshaped_filter_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - # If there is bias to be added then reshape and add it - if (len(inputsRef) == 3): - reshaped_bias_name = get_new_var_name(out_var_count) - reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 2) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - return (innermost_let_ast_node, out_var_count, output_name) - - def conv3d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - inputsRef = node.inputs - inputShape = value_info[inputsRef[0]][1] - filterShape = value_info[inputsRef[1]][1] - stridesUsed = node.attrs['strides'] - - assert(len(inputsRef)==2 or len(inputsRef)==3) - assert(len(stridesUsed)==3) - assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) - # verify this order - [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs['pads'] - - options = {} - options[AST.PaddingKeysDict.FD] = filterShape[2] - options[AST.PaddingKeysDict.FH] = filterShape[3] - options[AST.PaddingKeysDict.FW] = filterShape[4] - options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft - options[AST.PaddingKeysDict.zPadDRight] = zPadDRight - options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft - options[AST.PaddingKeysDict.zPadHRight] = zPadHRight - options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft - options[AST.PaddingKeysDict.zPadWRight] = zPadWRight - options[AST.PaddingKeysDict.strideD] = stridesUsed[0] - options[AST.PaddingKeysDict.strideH] = stridesUsed[1] - options[AST.PaddingKeysDict.strideW] = stridesUsed[2] - options[AST.PaddingKeysDict.ConvDim] = 3 - - assert (inputShape[1] == filterShape[1]) - # For Input: - # [N, CI, D, H, W] is the Onnx order it should be changed to - # [N, D, H, W, CI] order - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - - # For filter: - # [CO, CI1, FD, FH, FW] is the Onnx order it should be changed to - # [FD, FH, FW, CI1, CO] order - reshaped_filter_name = get_new_var_name(out_var_count) - reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#'), AST.ID(reshaped_filter_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - # If there is bias to be added then reshape and add it - if (len(inputsRef) == 3): - reshaped_bias_name = get_new_var_name(out_var_count) - reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 3) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - return (innermost_let_ast_node, out_var_count, output_name) - - def MaxPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - return ONNXNodesAST.helper_processPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST, 'MAXPOOL') - - def AvgPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - return ONNXNodesAST.helper_processPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST, 'AVGPOOL') - - def GlobalAveragePool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - inputsRef = node.inputs - assert(len(inputsRef)==1) - - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.Pool(AST.Pool.PoolType.AvgPool, - AST.ID(reshaped_input_name), - { - AST.PaddingKeysDict.FH: value_info[inputsRef[0]][1][2], - AST.PaddingKeysDict.FW: value_info[inputsRef[0]][1][3], - AST.PaddingKeysDict.zPadHLeft: 0, - AST.PaddingKeysDict.zPadHRight: 0, - AST.PaddingKeysDict.zPadWLeft: 0, - AST.PaddingKeysDict.zPadWRight: 0, - AST.PaddingKeysDict.strideH: 1, - AST.PaddingKeysDict.strideW: 1 - } - ) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - reshaped_output_name = get_new_var_name(out_var_count) - onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name - - return (innermost_let_ast_node, out_var_count) - - def helper_processPool(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST, typeOfPool): - node = OnnxNode(node) - if(DEBUG): - print(node) - inputsRef = node.inputs - assert(len(inputsRef)==1) - - stridesUsed = node.attrs['strides'] - strideH = stridesUsed[0] - strideW = stridesUsed[1] - - kSizeUsed = node.attrs['kernel_shape'] - # assert((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)) - kSizeH = kSizeUsed[0] - kSizeW = kSizeUsed[1] - - inputShape = value_info[inputsRef[0]][1] - # print(inputShape) - imgH = inputShape[2] - imgW = inputShape[3] - - # verify order - [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs['pads'] - - - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - - poolType = None - if typeOfPool=='MAXPOOL': poolType = AST.Pool.PoolType.MaxPool - elif typeOfPool=='AVGPOOL': poolType = AST.Pool.PoolType.AvgPool - else: - print("Unknown type of pooling layer.", file=sys.stderr) - assert(False) - seedot_output_ast = AST.Pool(poolType, - AST.ID(reshaped_input_name), - { - AST.PaddingKeysDict.FH: kSizeH, - AST.PaddingKeysDict.FW: kSizeW, - AST.PaddingKeysDict.zPadHLeft: zPadHLeft, - AST.PaddingKeysDict.zPadHRight: zPadHRight, - AST.PaddingKeysDict.zPadWLeft: zPadWLeft, - AST.PaddingKeysDict.zPadWRight: zPadWRight, - AST.PaddingKeysDict.strideH: strideH, - AST.PaddingKeysDict.strideW: strideW - } - ) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - - reshaped_output_name = get_new_var_name(out_var_count) - onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name - - return (innermost_let_ast_node, out_var_count) - - def ConvTranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - node = OnnxNode(node) - if(DEBUG): - print(node) - - inputsRef = node.inputs - # since two dimensions represent N: Number of batches and CI: Input channel - inputShape = value_info[inputsRef[0]][1] - spatial_size = len(inputShape)-2 - if spatial_size == 2: - (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv2dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) - elif spatial_size == 3: - (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv3dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) - - reshaped_output_name = get_new_var_name(out_var_count) - onnx_output_ast = get_reshaped_output_ast(node.outputs[0],value_info, output_name) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST) - out_var_count += 1 - node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name - - return (innermost_let_ast_node, out_var_count) - - def conv2dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - inputsRef = node.inputs - inputShape = value_info[inputsRef[0]][1] - filterShape = value_info[inputsRef[1]][1] - stridesUsed = node.attrs['strides'] - outputShape = value_info[node.outputs[0]][1] - - # sometimes there is a bias to be added as well - assert(len(inputsRef)==2 or len(inputsRef)==3) - assert(len(stridesUsed)==2) - assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) - # verify this order - [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs['pads'] - - options = {} - options[AST.PaddingKeysDict.FH] = filterShape[2] - options[AST.PaddingKeysDict.FW] = filterShape[3] - options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft - options[AST.PaddingKeysDict.zPadHRight] = zPadHRight - options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft - options[AST.PaddingKeysDict.zPadWRight] = zPadWRight - options[AST.PaddingKeysDict.strideH] = stridesUsed[0] - options[AST.PaddingKeysDict.strideW] = stridesUsed[1] - options[AST.PaddingKeysDict.ConvDim] = 2 - options[AST.PaddingKeysDict.outputImgH] = outputShape[2] - options[AST.PaddingKeysDict.outputImgW] = outputShape[3] - - assert (inputShape[1] == filterShape[0]) - # For Input: - # [N, CI, H, W] is the Onnx order it should be changed to - # [N, H, W, CI] order - - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - # For filter: - # [CI, CO, FH, FW] is the Onnx order it should be changed to - # [FH, FW, CI1, CO] order - reshaped_filter_name = get_new_var_name(out_var_count) - reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#T'), AST.ID(reshaped_filter_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - # If there is bias to be added then reshape and add it - if (len(inputsRef) == 3): - biasShape = value_info[inputsRef[2]][1] - reshaped_bias_name = get_new_var_name(out_var_count) - reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 2) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - return (innermost_let_ast_node, out_var_count, output_name) - - def conv3dtranspose(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST): - inputsRef = node.inputs - inputShape = value_info[inputsRef[0]][1] - filterShape = value_info[inputsRef[1]][1] - stridesUsed = node.attrs['strides'] - outputShape = value_info[node.outputs[0]][1] - - # sometimes there is a bias to be added as well - assert(len(inputsRef)==2 or len(inputsRef)==3) - assert(len(stridesUsed)==3) - assert(value_info[node.inputs[1]][1][2:] == tuple(node.attrs['kernel_shape'])) - # verify this order - [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs['pads'] - - options = {} - options[AST.PaddingKeysDict.FD] = filterShape[2] - options[AST.PaddingKeysDict.FH] = filterShape[3] - options[AST.PaddingKeysDict.FW] = filterShape[4] - options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft - options[AST.PaddingKeysDict.zPadDRight] = zPadDRight - options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft - options[AST.PaddingKeysDict.zPadHRight] = zPadHRight - options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft - options[AST.PaddingKeysDict.zPadWRight] = zPadWRight - options[AST.PaddingKeysDict.strideD] = stridesUsed[0] - options[AST.PaddingKeysDict.strideH] = stridesUsed[1] - options[AST.PaddingKeysDict.strideW] = stridesUsed[2] - options[AST.PaddingKeysDict.ConvDim] = 3 - options[AST.PaddingKeysDict.outputImgD] = outputShape[2] - options[AST.PaddingKeysDict.outputImgH] = outputShape[3] - options[AST.PaddingKeysDict.outputImgW] = outputShape[4] - - assert (inputShape[1] == filterShape[0]) - # For Input: - # [N, CI, D, H, W] is the Onnx order it should be changed to - # [N, D, H, W, CI] order - - reshaped_input_name = get_new_var_name(out_var_count) - reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST) - out_var_count += 1 - # For filter: - # [CI, CO, FD, FH, FW] is the Onnx order it should be changed to - # [FD, FH, FW, CI1, CO] order - reshaped_filter_name = get_new_var_name(out_var_count) - reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info, node_name_to_out_var_dict) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name), getOperatorsIdx('#T'), AST.ID(reshaped_filter_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - # If there is bias to be added then reshape and add it - if (len(inputsRef) == 3): - biasShape = value_info[inputsRef[2]][1] - reshaped_bias_name = get_new_var_name(out_var_count) - reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info, node_name_to_out_var_dict, 3) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST) - out_var_count += 1 - - seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'), AST.ID(reshaped_bias_name), options) - output_name = get_new_var_name(out_var_count) - innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) - out_var_count += 1 - - - - return (innermost_let_ast_node, out_var_count, output_name) - + # value_info: dictionary of name -> (type, dimension tuple) + def Input(node, value_info, node_name_to_out_var_dict, party=0): + if DEBUG: + print(node.outputs[0]) + # There are two types of inputs + dims = list( + node.dims + if hasattr(node, "dims") + else ([val.dim_value for val in node.type.tensor_type.shape.dim]) + ) + data_type = ( + node.data_type + if hasattr(node, "data_type") + else node.type.tensor_type.elem_type + ) + return AST.Input(dims, onnx2seedot(data_type), inputByParty=party) + + def Cast( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + inputsRef = node.inputs + assert len(inputsRef) == 1 + # destType = node.attrs['to'] + + # seedot_output_ast = AST.UninterpFuncCall(value_info[node.outputs[0]][1], + # 'Cast', + # [AST.ID(inputsRef[0]), + # AST.ID(destType), + # AST.ID(destType) + # ]) + # output_name = get_new_var_name(out_var_count) + # innermost_let_ast_node = update_program_with_new_node(innermost_let_ast_node, seedot_output_ast, output_name, mtdAST) + # out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = inputsRef[0] + + return (innermost_let_ast_node, out_var_count) + + def Pad( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + inputsRef = node.inputs + # Skip constant_val input (last input) + inpLen = len(inputsRef) - 1 + assert inpLen == 2 + inputs = [ + AST.ID(node_name_to_out_var_dict[inputsRef[x]]) for x in range(0, inpLen) + ] + mode = node.attrs["mode"] + assert mode == "constant" + seedot_output_ast = AST.UninterpFuncCall( + list(value_info[node.outputs[0]][1]), "PadONNX", inputs + ) + + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Concat( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + inputsRef = node.inputs + N = len(inputsRef) + + inputs = [ + AST.ID(node_name_to_out_var_dict[inputsRef[x]]) + for x in range(0, len(inputsRef)) + ] + axis = node.attrs["axis"] + + seedot_output_ast = AST.UninterpFuncCall( + list(value_info[node.outputs[0]][1]), + "Concat" + str(N) + "T", + inputs + [AST.Int(axis, 32, False)], + outputDiffInpDims=1, + ) + + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Relu( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + + inputsRef = node.inputs + assert len(inputsRef) == 1 + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.Func( + getOperatorsIdx("relu"), AST.ID(reshaped_input_name) + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast( + node.outputs[0], value_info, output_name + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + if DEBUG: + print(node.outputs[0]) + print(onnx_input_shape, "->", seedot_input_shape, "->", onnx_output_shape) + + return (innermost_let_ast_node, out_var_count) + # return AST.Func(getOperatorsIdx('relu'), AST.ID(node_name_to_out_var_dict[inputsRef[0]])) + + def Add( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + inputsRef = node.inputs + assert len(inputsRef) == 2 + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + + reshaped_input_name1 = get_new_var_name(out_var_count) + reshaped_input1 = get_reshaped_input_ast( + inputsRef[1], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input1, reshaped_input_name1, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(reshaped_input_name), + getOperatorsIdx("+"), + AST.ID(reshaped_input_name1), + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast( + node.outputs[0], value_info, output_name + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + if DEBUG: + print(node.outputs[0]) + print( + onnx_input_shape, + onnx_input_shape1, + "->", + seedot_input_shape, + seedot_input_shape1, + "->", + onnx_output_shape, + ) + + return (innermost_let_ast_node, out_var_count) + + def Gemm( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + inputsRef = node.inputs + assert len(inputsRef) == 3 + input1AST = AST.ID(node_name_to_out_var_dict[inputsRef[0]]) + input2AST = AST.ID(node_name_to_out_var_dict[inputsRef[1]]) + + if "transA" in node.attrs and node.attrs["transA"]: + input1AST = AST.Transpose(input1AST) + if "transB" in node.attrs and node.attrs["transB"]: + input2AST = AST.Transpose(input2AST) + + # W*x + b + seedot_output_ast = AST.BOp( + AST.BOp(input1AST, getOperatorsIdx("*"), input2AST), + getOperatorsIdx("+"), + AST.ID(node_name_to_out_var_dict[inputsRef[2]]), + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Constant( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + # TODO: Use AST.decl for defining a tensor. If used as a parameter for Reshape then we don't need it for now. + return (innermost_let_ast_node, out_var_count) + + def Transpose( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + + inputsRef = node.inputs + assert len(inputsRef) == 1 + + seedot_output_ast = AST.Transpose( + AST.ID(node_name_to_out_var_dict[inputsRef[0]]), node.attrs["perm"] + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + # Only supports split into equal parts + def Split( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + inputsRef = node.inputs + output_count = len(node.outputs) + + for cur_count in range(output_count): + seedot_output_ast = AST.UninterpFuncCall( + list(value_info[node.outputs[cur_count]][1]), + "Split", + [ + AST.ID(node_name_to_out_var_dict[inputsRef[0]]), + AST.Int(node.attrs["axis"], 32, False), + AST.Int(cur_count, 32, False), + AST.Int(output_count, 32, False), + ], + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[cur_count]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def ReduceMean( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + inputsRef = node.inputs + + keepdims = node.attrs["keepdims"] + axes = node.attrs["axes"] + + # currently handling only this case + # currently support only 0 case + assert keepdims == 0 + assert len(axes) == 2 + + seedot_output_ast = AST.UninterpFuncCall( + value_info[node.outputs[0]][1], + "ReduceMeanONNX", + [ + AST.ID(node_name_to_out_var_dict[inputsRef[0]]), + AST.Int(axes[0], 32, False), + AST.Int(axes[1], 32, False), + ], + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + return (innermost_let_ast_node, out_var_count) + + def BatchNormalization( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + + inputsRef = node.inputs + # Are running mean and var used for something? + assert len(inputsRef) == 5 + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.FusedBatchNorm( + AST.ID(reshaped_input_name), + AST.ID(node_name_to_out_var_dict[inputsRef[1]]), + AST.ID(node_name_to_out_var_dict[inputsRef[2]]), + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast( + node.outputs[0], value_info, output_name + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + if DEBUG: + print(node.outputs[0]) + print(onnx_input_shape, "->", seedot_input_shape, "->", onnx_output_shape) + + return (innermost_let_ast_node, out_var_count) + + def Reshape( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + + inputsRef = node.inputs + assert len(inputsRef) == 2 + # print(list(value_info[node.outputs[0]][1])) + + seedot_output_ast = AST.Reshape( + AST.ID(node_name_to_out_var_dict[inputsRef[0]]), + list(value_info[node.outputs[0]][1]), + None, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Flatten( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + + inputsRef = node.inputs + assert len(inputsRef) == 1 + + seedot_output_ast = AST.Reshape( + AST.ID(node_name_to_out_var_dict[inputsRef[0]]), + list(value_info[node.outputs[0]][1]), + None, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = output_name + + return (innermost_let_ast_node, out_var_count) + + def Conv( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + + inputsRef = node.inputs + # since two dimensions represent N: Number of batches and CI: Input channel + inputShape = value_info[inputsRef[0]][1] + spatial_size = len(inputShape) - 2 + + if spatial_size == 2: + (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv2d( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ) + elif spatial_size == 3: + (innermost_let_ast_node, out_var_count, output_name) = ONNXNodesAST.conv3d( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ) + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast( + node.outputs[0], value_info, output_name + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def conv2d( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + + stridesUsed = node.attrs["strides"] + + assert len(inputsRef) == 2 or len(inputsRef) == 3 + assert len(stridesUsed) == 2 + assert value_info[node.inputs[1]][1][2:] == tuple(node.attrs["kernel_shape"]) + + group = node.attrs["group"] if "group" in node.attrs else 1 + [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = ( + node.attrs["pads"] if "pads" in node.attrs else [0, 0, 0, 0] + ) + # we assume VALID case when the padding is in string format + + options = {} + options[AST.PaddingKeysDict.FH] = filterShape[2] + options[AST.PaddingKeysDict.FW] = filterShape[3] + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideH] = stridesUsed[0] + options[AST.PaddingKeysDict.strideW] = stridesUsed[1] + options[AST.PaddingKeysDict.ConvDim] = 2 + options[AST.PaddingKeysDict.group] = group + + # print(inputShape, filterShape) + assert inputShape[1] == filterShape[1] * group + # For Input: + # [N, CI, H, W] is the Onnx order it should be changed to + # [N, H, W, CI] order + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + + # For filter: + # [CO, CI1, FH, FW] is the Onnx order it should be changed to + # [FH, FW, CI1, CO] order + reshaped_filter_name = get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast( + inputsRef[1], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(reshaped_input_name), + getOperatorsIdx("#"), + AST.ID(reshaped_filter_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if len(inputsRef) == 3: + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast( + inputsRef[2], value_info, node_name_to_out_var_dict, 2 + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(output_name), + getOperatorsIdx("+"), + AST.ID(reshaped_bias_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + return (innermost_let_ast_node, out_var_count, output_name) + + def conv3d( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + stridesUsed = node.attrs["strides"] + + assert len(inputsRef) == 2 or len(inputsRef) == 3 + assert len(stridesUsed) == 3 + assert value_info[node.inputs[1]][1][2:] == tuple(node.attrs["kernel_shape"]) + # verify this order + [ + zPadDLeft, + zPadDRight, + zPadHLeft, + zPadHRight, + zPadWLeft, + zPadWRight, + ] = node.attrs["pads"] + + options = {} + options[AST.PaddingKeysDict.FD] = filterShape[2] + options[AST.PaddingKeysDict.FH] = filterShape[3] + options[AST.PaddingKeysDict.FW] = filterShape[4] + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = stridesUsed[0] + options[AST.PaddingKeysDict.strideH] = stridesUsed[1] + options[AST.PaddingKeysDict.strideW] = stridesUsed[2] + options[AST.PaddingKeysDict.ConvDim] = 3 + + assert inputShape[1] == filterShape[1] + # For Input: + # [N, CI, D, H, W] is the Onnx order it should be changed to + # [N, D, H, W, CI] order + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + + # For filter: + # [CO, CI1, FD, FH, FW] is the Onnx order it should be changed to + # [FD, FH, FW, CI1, CO] order + reshaped_filter_name = get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast( + inputsRef[1], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(reshaped_input_name), + getOperatorsIdx("#"), + AST.ID(reshaped_filter_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if len(inputsRef) == 3: + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast( + inputsRef[2], value_info, node_name_to_out_var_dict, 3 + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(output_name), + getOperatorsIdx("+"), + AST.ID(reshaped_bias_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + return (innermost_let_ast_node, out_var_count, output_name) + + def MaxPool( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + return ONNXNodesAST.helper_processPool( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + "MAXPOOL", + ) + + def AvgPool( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + return ONNXNodesAST.helper_processPool( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + "AVGPOOL", + ) + + def GlobalAveragePool( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + inputsRef = node.inputs + assert len(inputsRef) == 1 + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.Pool( + AST.Pool.PoolType.AvgPool, + AST.ID(reshaped_input_name), + { + AST.PaddingKeysDict.FH: value_info[inputsRef[0]][1][2], + AST.PaddingKeysDict.FW: value_info[inputsRef[0]][1][3], + AST.PaddingKeysDict.zPadHLeft: 0, + AST.PaddingKeysDict.zPadHRight: 0, + AST.PaddingKeysDict.zPadWLeft: 0, + AST.PaddingKeysDict.zPadWRight: 0, + AST.PaddingKeysDict.strideH: 1, + AST.PaddingKeysDict.strideW: 1, + }, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast( + node.outputs[0], value_info, output_name + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def helper_processPool( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + typeOfPool, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + inputsRef = node.inputs + assert len(inputsRef) == 1 + + stridesUsed = node.attrs["strides"] + strideH = stridesUsed[0] + strideW = stridesUsed[1] + + kSizeUsed = node.attrs["kernel_shape"] + # assert((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)) + kSizeH = kSizeUsed[0] + kSizeW = kSizeUsed[1] + + inputShape = value_info[inputsRef[0]][1] + # print(inputShape) + imgH = inputShape[2] + imgW = inputShape[3] + + # verify order + [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs["pads"] + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + + poolType = None + if typeOfPool == "MAXPOOL": + poolType = AST.Pool.PoolType.MaxPool + elif typeOfPool == "AVGPOOL": + poolType = AST.Pool.PoolType.AvgPool + else: + print("Unknown type of pooling layer.", file=sys.stderr) + assert False + seedot_output_ast = AST.Pool( + poolType, + AST.ID(reshaped_input_name), + { + AST.PaddingKeysDict.FH: kSizeH, + AST.PaddingKeysDict.FW: kSizeW, + AST.PaddingKeysDict.zPadHLeft: zPadHLeft, + AST.PaddingKeysDict.zPadHRight: zPadHRight, + AST.PaddingKeysDict.zPadWLeft: zPadWLeft, + AST.PaddingKeysDict.zPadWRight: zPadWRight, + AST.PaddingKeysDict.strideH: strideH, + AST.PaddingKeysDict.strideW: strideW, + }, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast( + node.outputs[0], value_info, output_name + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def ConvTranspose( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + node = OnnxNode(node) + if DEBUG: + print(node) + + inputsRef = node.inputs + # since two dimensions represent N: Number of batches and CI: Input channel + inputShape = value_info[inputsRef[0]][1] + spatial_size = len(inputShape) - 2 + if spatial_size == 2: + ( + innermost_let_ast_node, + out_var_count, + output_name, + ) = ONNXNodesAST.conv2dtranspose( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ) + elif spatial_size == 3: + ( + innermost_let_ast_node, + out_var_count, + output_name, + ) = ONNXNodesAST.conv3dtranspose( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ) + + reshaped_output_name = get_new_var_name(out_var_count) + onnx_output_ast = get_reshaped_output_ast( + node.outputs[0], value_info, output_name + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, onnx_output_ast, reshaped_output_name, mtdAST + ) + out_var_count += 1 + node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name + + return (innermost_let_ast_node, out_var_count) + + def conv2dtranspose( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + stridesUsed = node.attrs["strides"] + outputShape = value_info[node.outputs[0]][1] + + # sometimes there is a bias to be added as well + assert len(inputsRef) == 2 or len(inputsRef) == 3 + assert len(stridesUsed) == 2 + assert value_info[node.inputs[1]][1][2:] == tuple(node.attrs["kernel_shape"]) + # verify this order + [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = node.attrs["pads"] + + options = {} + options[AST.PaddingKeysDict.FH] = filterShape[2] + options[AST.PaddingKeysDict.FW] = filterShape[3] + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideH] = stridesUsed[0] + options[AST.PaddingKeysDict.strideW] = stridesUsed[1] + options[AST.PaddingKeysDict.ConvDim] = 2 + options[AST.PaddingKeysDict.outputImgH] = outputShape[2] + options[AST.PaddingKeysDict.outputImgW] = outputShape[3] + + assert inputShape[1] == filterShape[0] + # For Input: + # [N, CI, H, W] is the Onnx order it should be changed to + # [N, H, W, CI] order + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + # For filter: + # [CI, CO, FH, FW] is the Onnx order it should be changed to + # [FH, FW, CI1, CO] order + reshaped_filter_name = get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast( + inputsRef[1], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(reshaped_input_name), + getOperatorsIdx("#T"), + AST.ID(reshaped_filter_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if len(inputsRef) == 3: + biasShape = value_info[inputsRef[2]][1] + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast( + inputsRef[2], value_info, node_name_to_out_var_dict, 2 + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(output_name), + getOperatorsIdx("+"), + AST.ID(reshaped_bias_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + return (innermost_let_ast_node, out_var_count, output_name) + + def conv3dtranspose( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ): + inputsRef = node.inputs + inputShape = value_info[inputsRef[0]][1] + filterShape = value_info[inputsRef[1]][1] + stridesUsed = node.attrs["strides"] + outputShape = value_info[node.outputs[0]][1] + + # sometimes there is a bias to be added as well + assert len(inputsRef) == 2 or len(inputsRef) == 3 + assert len(stridesUsed) == 3 + assert value_info[node.inputs[1]][1][2:] == tuple(node.attrs["kernel_shape"]) + # verify this order + [ + zPadDLeft, + zPadDRight, + zPadHLeft, + zPadHRight, + zPadWLeft, + zPadWRight, + ] = node.attrs["pads"] + + options = {} + options[AST.PaddingKeysDict.FD] = filterShape[2] + options[AST.PaddingKeysDict.FH] = filterShape[3] + options[AST.PaddingKeysDict.FW] = filterShape[4] + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = stridesUsed[0] + options[AST.PaddingKeysDict.strideH] = stridesUsed[1] + options[AST.PaddingKeysDict.strideW] = stridesUsed[2] + options[AST.PaddingKeysDict.ConvDim] = 3 + options[AST.PaddingKeysDict.outputImgD] = outputShape[2] + options[AST.PaddingKeysDict.outputImgH] = outputShape[3] + options[AST.PaddingKeysDict.outputImgW] = outputShape[4] + + assert inputShape[1] == filterShape[0] + # For Input: + # [N, CI, D, H, W] is the Onnx order it should be changed to + # [N, D, H, W, CI] order + + reshaped_input_name = get_new_var_name(out_var_count) + reshaped_input = get_reshaped_input_ast( + inputsRef[0], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST + ) + out_var_count += 1 + # For filter: + # [CI, CO, FD, FH, FW] is the Onnx order it should be changed to + # [FD, FH, FW, CI1, CO] order + reshaped_filter_name = get_new_var_name(out_var_count) + reshaped_filter = get_reshaped_filter_ast( + inputsRef[1], value_info, node_name_to_out_var_dict + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(reshaped_input_name), + getOperatorsIdx("#T"), + AST.ID(reshaped_filter_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + # If there is bias to be added then reshape and add it + if len(inputsRef) == 3: + biasShape = value_info[inputsRef[2]][1] + reshaped_bias_name = get_new_var_name(out_var_count) + reshaped_bias = get_reshaped_bias_ast( + inputsRef[2], value_info, node_name_to_out_var_dict, 3 + ) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST + ) + out_var_count += 1 + + seedot_output_ast = AST.BOp( + AST.ID(output_name), + getOperatorsIdx("+"), + AST.ID(reshaped_bias_name), + options, + ) + output_name = get_new_var_name(out_var_count) + innermost_let_ast_node = update_program_with_new_node( + innermost_let_ast_node, seedot_output_ast, output_name, mtdAST + ) + out_var_count += 1 + + return (innermost_let_ast_node, out_var_count, output_name) diff --git a/Athos/ONNXCompiler/common.py b/Athos/ONNXCompiler/common.py index f343cb5e..46dfd917 100644 --- a/Athos/ONNXCompiler/common.py +++ b/Athos/ONNXCompiler/common.py @@ -1,5 +1,4 @@ - -''' +""" Authors: Shubham Ugare. @@ -21,89 +20,106 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy import os import _pickle as pickle import re + def proto_val_to_dimension_tuple(proto_val): - return tuple([dim.dim_value for dim in proto_val.type.tensor_type.shape.dim]) + return tuple([dim.dim_value for dim in proto_val.type.tensor_type.shape.dim]) + def numpy_float_array_to_fixed_point_val_str(input_array, scale): - cnt = 0 - chunk = '' - for val in numpy.nditer(input_array): - val = int(val*(2**scale)) - chunk += str(val) + '\n' - cnt += 1 - return (chunk, cnt) + cnt = 0 + chunk = "" + for val in numpy.nditer(input_array): + val = int(val * (2 ** scale)) + chunk += str(val) + "\n" + cnt += 1 + return (chunk, cnt) + def numpy_float_array_to_float_val_str(input_array): - chunk = '' - for val in numpy.nditer(input_array): - chunk += str(val) + '\n' - return chunk + chunk = "" + for val in numpy.nditer(input_array): + chunk += str(val) + "\n" + return chunk + def write_debug_info(node_name_to_out_var_dict): - if not os.path.exists('debug'): - os.makedirs('debug') + if not os.path.exists("debug"): + os.makedirs("debug") - with open('debug/onnx_seedot_name_map.pkl', 'wb') as f: - pickle.dump(node_name_to_out_var_dict, f) + with open("debug/onnx_seedot_name_map.pkl", "wb") as f: + pickle.dump(node_name_to_out_var_dict, f) - with open('debug/onnx_seedot_name_map.txt', 'w') as f: - for val in node_name_to_out_var_dict: - f.write(val + ' ' + node_name_to_out_var_dict[val] + '\n') + with open("debug/onnx_seedot_name_map.txt", "w") as f: + for val in node_name_to_out_var_dict: + f.write(val + " " + node_name_to_out_var_dict[val] + "\n") def merge_name_map(): - onnx_seedot_name_map = pickle.load(open('debug/onnx_seedot_name_map.pkl', 'rb')) - seedot_ezpc_name_map = pickle.load(open('debug/seedot_ezpc_name_map.pkl', 'rb')) + onnx_seedot_name_map = pickle.load(open("debug/onnx_seedot_name_map.pkl", "rb")) + seedot_ezpc_name_map = pickle.load(open("debug/seedot_ezpc_name_map.pkl", "rb")) + + with open("debug/onnx_ezpc_name_map.txt", "w") as f: + for val in onnx_seedot_name_map: + f.write(val + " " + seedot_ezpc_name_map[onnx_seedot_name_map[val]]) - with open('debug/onnx_ezpc_name_map.txt', 'w') as f: - for val in onnx_seedot_name_map: - f.write(val + ' ' + seedot_ezpc_name_map[onnx_seedot_name_map[val]]) def get_seedot_name_from_onnx_name(onnx_name): - onnx_seedot_name_map = pickle.load(open('debug/onnx_seedot_name_map.pkl', 'rb')) - print(onnx_seedot_name_map[onnx_name]) + onnx_seedot_name_map = pickle.load(open("debug/onnx_seedot_name_map.pkl", "rb")) + print(onnx_seedot_name_map[onnx_name]) + def parse_output(scale): - f = open('debug/cpp_output_raw.txt', 'r') - g = open('debug/cpp_output.txt', 'w') - chunk = '' - for line in f: - if line.rstrip().replace('-','0').isdigit(): - val = float(line.rstrip()) - val = val/(2**scale) - chunk += str(val) + '\n' - g.write(chunk) - g.close() + f = open("debug/cpp_output_raw.txt", "r") + g = open("debug/cpp_output.txt", "w") + chunk = "" + for line in f: + if line.rstrip().replace("-", "0").isdigit(): + val = float(line.rstrip()) + val = val / (2 ** scale) + chunk += str(val) + "\n" + g.write(chunk) + g.close() + def extract_txt_to_numpy_array(file): - f = open(file, 'r') - op = [float(line.rstrip()) for line in f] - f.close() - return numpy.array(op, dtype=numpy.float32) + f = open(file, "r") + op = [float(line.rstrip()) for line in f] + f.close() + return numpy.array(op, dtype=numpy.float32) + def match_debug(decimal=4): - a = extract_txt_to_numpy_array('debug/onnx_debug.txt') - b = extract_txt_to_numpy_array('debug/cpp_output.txt') - numpy.testing.assert_almost_equal(a, b, decimal) + a = extract_txt_to_numpy_array("debug/onnx_debug.txt") + b = extract_txt_to_numpy_array("debug/cpp_output.txt") + numpy.testing.assert_almost_equal(a, b, decimal) + def match_output(decimal=4): - a = extract_txt_to_numpy_array('debug/onnx_output.txt') - b = extract_txt_to_numpy_array('debug/cpp_output.txt') - numpy.testing.assert_almost_equal(a, b, decimal) - -def add_openmp_threading_to_convolution(file): - with open(file, 'r+') as f: - newfilename = file[:-5]+'1.cpp' - g = open(newfilename, 'w') - content = f.read() - content1 = re.sub('void Conv3DLoopInner\(.*','\g<0> \n #pragma omp parallel for collapse(5) ', content) - content2 = re.sub('void ConvTranspose3DLoopInner\(.*','\g<0> \n #pragma omp parallel for collapse(5) ', content1) - g.write(content2) - g.close() + a = extract_txt_to_numpy_array("debug/onnx_output.txt") + b = extract_txt_to_numpy_array("debug/cpp_output.txt") + numpy.testing.assert_almost_equal(a, b, decimal) + +def add_openmp_threading_to_convolution(file): + with open(file, "r+") as f: + newfilename = file[:-5] + "1.cpp" + g = open(newfilename, "w") + content = f.read() + content1 = re.sub( + "void Conv3DLoopInner\(.*", + "\g<0> \n #pragma omp parallel for collapse(5) ", + content, + ) + content2 = re.sub( + "void ConvTranspose3DLoopInner\(.*", + "\g<0> \n #pragma omp parallel for collapse(5) ", + content1, + ) + g.write(content2) + g.close() diff --git a/Athos/ONNXCompiler/create_input.py b/Athos/ONNXCompiler/create_input.py index 81717f56..2afca00d 100644 --- a/Athos/ONNXCompiler/create_input.py +++ b/Athos/ONNXCompiler/create_input.py @@ -1,5 +1,4 @@ - -''' +""" Authors: Shubham Ugare. @@ -21,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy.random import numpy as np @@ -32,83 +31,101 @@ import math from onnx import numpy_helper + def main(): - if (len(sys.argv) < 3): - print("Model file or scaling factor unspecified.", file=sys.stderr) - exit(1) - - file_name = sys.argv[1] - scaling_factor = int(sys.argv[2]) - file_path = 'models/' + file_name - model_name = file_name[:-5] # name without the '.onnx' extension - model = onnx.load(file_path) - graph_def = model.graph - - # Generating input - input_dims = common.proto_val_to_dimension_tuple(model.graph.input[0]) - input_array = numpy.random.random(input_dims) - # input_array = numpy.ones(input_dims, dtype=float) - print('Generated random input of dimension ' + str(input_dims)) - np.save('debug/' + model_name + '/' + model_name + '_input', input_array) - - (chunk, cnt) = common.numpy_float_array_to_fixed_point_val_str(input_array, scaling_factor) - f = open('debug/' + model_name + '/' + model_name + '_input.inp', 'w') - f.write(chunk) - f.close() - - model_name_to_val_dict = { init_vals.name: numpy_helper.to_array(init_vals).tolist() for init_vals in model.graph.initializer} - - preprocess_batch_normalization(graph_def, model_name_to_val_dict) - - chunk_n = '' - cnt_n = 0 - for init_vals in model.graph.initializer: - (chunk_1, cnt_1) = common.numpy_float_array_to_fixed_point_val_str( - np.asarray(model_name_to_val_dict[init_vals.name], dtype=np.float32), scaling_factor) - chunk_n += chunk_1 - cnt_n += cnt_1 - - f = open('debug/' + model_name + '/' + model_name + '_weights.inp', 'w') - f.write(chunk_n) - f.close() - - f = open('debug/' + model_name + '/' + model_name + '_combined_input_weights.inp', 'w') - f.write(chunk + chunk_n) - f.close() - - print('Total ' + str(cnt + cnt_n) + ' integers were written in ' + model_name + '_combined_input_weights.inp') + if len(sys.argv) < 3: + print("Model file or scaling factor unspecified.", file=sys.stderr) + exit(1) + + file_name = sys.argv[1] + scaling_factor = int(sys.argv[2]) + file_path = "models/" + file_name + model_name = file_name[:-5] # name without the '.onnx' extension + model = onnx.load(file_path) + graph_def = model.graph + + # Generating input + input_dims = common.proto_val_to_dimension_tuple(model.graph.input[0]) + input_array = numpy.random.random(input_dims) + # input_array = numpy.ones(input_dims, dtype=float) + print("Generated random input of dimension " + str(input_dims)) + np.save("debug/" + model_name + "/" + model_name + "_input", input_array) + + (chunk, cnt) = common.numpy_float_array_to_fixed_point_val_str( + input_array, scaling_factor + ) + f = open("debug/" + model_name + "/" + model_name + "_input.inp", "w") + f.write(chunk) + f.close() + + model_name_to_val_dict = { + init_vals.name: numpy_helper.to_array(init_vals).tolist() + for init_vals in model.graph.initializer + } + + preprocess_batch_normalization(graph_def, model_name_to_val_dict) + + chunk_n = "" + cnt_n = 0 + for init_vals in model.graph.initializer: + (chunk_1, cnt_1) = common.numpy_float_array_to_fixed_point_val_str( + np.asarray(model_name_to_val_dict[init_vals.name], dtype=np.float32), + scaling_factor, + ) + chunk_n += chunk_1 + cnt_n += cnt_1 + + f = open("debug/" + model_name + "/" + model_name + "_weights.inp", "w") + f.write(chunk_n) + f.close() + + f = open( + "debug/" + model_name + "/" + model_name + "_combined_input_weights.inp", "w" + ) + f.write(chunk + chunk_n) + f.close() + + print( + "Total " + + str(cnt + cnt_n) + + " integers were written in " + + model_name + + "_combined_input_weights.inp" + ) + def preprocess_batch_normalization(graph_def, model_name_to_val_dict): - # set names to graph nodes if not present - for node in graph_def.node: - node.name = node.output[0] - # Update the batch normalization scale and B - # so that mean and var are not required - if(node.op_type == 'BatchNormalization'): - # scale - gamma = model_name_to_val_dict[node.input[1]] - # B - beta = model_name_to_val_dict[node.input[2]] - mean = model_name_to_val_dict[node.input[3]] - var = model_name_to_val_dict[node.input[4]] - for i in range(len(gamma)): - rsigma = 1/math.sqrt(var[i]+1e-5) - gamma[i] = gamma[i]*rsigma - beta[i] = beta[i]-gamma[i]*mean[i] - mean[i] = 0 - var[i] = 1-1e-5 - - # Just testing if the correct values are put - model_name_to_val_dict2 = {} - for init_vals in graph_def.initializer: - # TODO: Remove float_data - model_name_to_val_dict2[init_vals.name] = init_vals.float_data - for node in graph_def.node: - node.name = node.output[0] - if(node.op_type == 'BatchNormalization'): - mean = model_name_to_val_dict[node.input[3]] - for val in mean: - assert(val == 0) + # set names to graph nodes if not present + for node in graph_def.node: + node.name = node.output[0] + # Update the batch normalization scale and B + # so that mean and var are not required + if node.op_type == "BatchNormalization": + # scale + gamma = model_name_to_val_dict[node.input[1]] + # B + beta = model_name_to_val_dict[node.input[2]] + mean = model_name_to_val_dict[node.input[3]] + var = model_name_to_val_dict[node.input[4]] + for i in range(len(gamma)): + rsigma = 1 / math.sqrt(var[i] + 1e-5) + gamma[i] = gamma[i] * rsigma + beta[i] = beta[i] - gamma[i] * mean[i] + mean[i] = 0 + var[i] = 1 - 1e-5 + + # Just testing if the correct values are put + model_name_to_val_dict2 = {} + for init_vals in graph_def.initializer: + # TODO: Remove float_data + model_name_to_val_dict2[init_vals.name] = init_vals.float_data + for node in graph_def.node: + node.name = node.output[0] + if node.op_type == "BatchNormalization": + mean = model_name_to_val_dict[node.input[3]] + for val in mean: + assert val == 0 + if __name__ == "__main__": - main() + main() diff --git a/Athos/ONNXCompiler/onnx_run.py b/Athos/ONNXCompiler/onnx_run.py index 905da810..046bcc7a 100644 --- a/Athos/ONNXCompiler/onnx_run.py +++ b/Athos/ONNXCompiler/onnx_run.py @@ -1,5 +1,4 @@ - -''' +""" Authors: Shubham Ugare. @@ -21,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy as np import onnxruntime @@ -31,37 +30,40 @@ from onnx import helper # First read the ONNX file -if (len(sys.argv) < 2): - print("TF python file unspecified.", file=sys.stderr) - exit(1) +if len(sys.argv) < 2: + print("TF python file unspecified.", file=sys.stderr) + exit(1) file_name = sys.argv[1] -file_path = 'models/' + file_name -model_name = file_name[:-5] # name without the '.onnx' extension +file_path = "models/" + file_name +model_name = file_name[:-5] # name without the '.onnx' extension model = onnx.load(file_path) -sess = onnxruntime.InferenceSession(file_path) +sess = onnxruntime.InferenceSession(file_path) -x = np.load('debug/' + model_name + '/' + model_name + '_input.npy') +x = np.load("debug/" + model_name + "/" + model_name + "_input.npy") x = x.astype(np.float32) input_name = model.graph.input[0].name -if (len(sys.argv) > 2): - intermediate_layer_value_info = helper.ValueInfoProto() - intermediate_layer_value_info.name = sys.argv[2] - model.graph.output.extend([intermediate_layer_value_info]) - onnx.save(model, file_path + '_1') - sess = onnxruntime.InferenceSession(file_path + '_1') - pred = sess.run([intermediate_layer_value_info.name], {input_name: x}) - np.save('debug/' + model_name + '/' + model_name + '_debug', pred) - with open('debug/onnx_debug.txt', 'w') as f: - f.write(common.numpy_float_array_to_float_val_str(pred)) - print("Saving the onnx runtime intermediate output for " + intermediate_layer_value_info.name) - exit() +if len(sys.argv) > 2: + intermediate_layer_value_info = helper.ValueInfoProto() + intermediate_layer_value_info.name = sys.argv[2] + model.graph.output.extend([intermediate_layer_value_info]) + onnx.save(model, file_path + "_1") + sess = onnxruntime.InferenceSession(file_path + "_1") + pred = sess.run([intermediate_layer_value_info.name], {input_name: x}) + np.save("debug/" + model_name + "/" + model_name + "_debug", pred) + with open("debug/onnx_debug.txt", "w") as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) + print( + "Saving the onnx runtime intermediate output for " + + intermediate_layer_value_info.name + ) + exit() pred = sess.run(None, {input_name: x}) -np.save('debug/' + model_name + '/' + model_name + '_output', pred) -with open('debug/onnx_output.txt', 'w') as f: - f.write(common.numpy_float_array_to_float_val_str(pred)) +np.save("debug/" + model_name + "/" + model_name + "_output", pred) +with open("debug/onnx_output.txt", "w") as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) output_dims = common.proto_val_to_dimension_tuple(model.graph.output[0]) print("Saving the onnx runtime output of dimension " + str(output_dims)) diff --git a/Athos/ONNXCompiler/onnx_run_tf.py b/Athos/ONNXCompiler/onnx_run_tf.py index 4be6d2cd..50cb9557 100644 --- a/Athos/ONNXCompiler/onnx_run_tf.py +++ b/Athos/ONNXCompiler/onnx_run_tf.py @@ -1,5 +1,4 @@ - -''' +""" Authors: Shubham Ugare. @@ -21,12 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" -''' +""" onnx_run is faster but may not support all operations onnx_run_tf uses tensorflow backend to run the inference -''' +""" import numpy as np import common @@ -36,62 +35,70 @@ from onnx_tf.backend import prepare from onnx import TensorProto + def main(): - # First read the ONNX file - if (len(sys.argv) < 2): - print("TF python file unspecified.", file=sys.stderr) - exit(1) - - file_name = sys.argv[1] - file_path = 'models/' + file_name - model_name = file_name[:-5] # name without the '.onnx' extension - model = onnx.load(file_path) - model = preprocess_for_tf(model) - - x = np.load('debug/' + model_name + '/' + model_name + '_input.npy') - x = x.astype(np.float32) - - input_name = model.graph.input[0].name - output_name = model.graph.output[0].name - - if (len(sys.argv) > 2): - intermediate_layer_value_info = helper.ValueInfoProto() - intermediate_layer_value_info_name = 'tf_' + sys.argv[2] - intermediate_layer_value_info = helper.make_tensor_value_info(intermediate_layer_value_info_name, TensorProto.FLOAT, []) - model.graph.output.extend([intermediate_layer_value_info]) - output = prepare(model).run(x) - pred = getattr(output, intermediate_layer_value_info_name) - np.save('debug/' + model_name + '/' + model_name + '_debug', pred) - with open('debug/onnx_debug.txt', 'w') as f: - f.write(common.numpy_float_array_to_float_val_str(pred)) - print("Saving the onnx runtime intermediate output for " + intermediate_layer_value_info.name) - exit() - - output = prepare(model).run(x) - pred = getattr(output, output_name) - np.save('debug/' + model_name + '/' + model_name + '_output', pred) - with open('debug/onnx_output.txt', 'w') as f: - f.write(common.numpy_float_array_to_float_val_str(pred)) - output_dims = common.proto_val_to_dimension_tuple(model.graph.output[0]) - print("Saving the onnx runtime output of dimension " + str(output_dims)) + # First read the ONNX file + if len(sys.argv) < 2: + print("TF python file unspecified.", file=sys.stderr) + exit(1) + + file_name = sys.argv[1] + file_path = "models/" + file_name + model_name = file_name[:-5] # name without the '.onnx' extension + model = onnx.load(file_path) + model = preprocess_for_tf(model) + + x = np.load("debug/" + model_name + "/" + model_name + "_input.npy") + x = x.astype(np.float32) + + input_name = model.graph.input[0].name + output_name = model.graph.output[0].name + + if len(sys.argv) > 2: + intermediate_layer_value_info = helper.ValueInfoProto() + intermediate_layer_value_info_name = "tf_" + sys.argv[2] + intermediate_layer_value_info = helper.make_tensor_value_info( + intermediate_layer_value_info_name, TensorProto.FLOAT, [] + ) + model.graph.output.extend([intermediate_layer_value_info]) + output = prepare(model).run(x) + pred = getattr(output, intermediate_layer_value_info_name) + np.save("debug/" + model_name + "/" + model_name + "_debug", pred) + with open("debug/onnx_debug.txt", "w") as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) + print( + "Saving the onnx runtime intermediate output for " + + intermediate_layer_value_info.name + ) + exit() + + output = prepare(model).run(x) + pred = getattr(output, output_name) + np.save("debug/" + model_name + "/" + model_name + "_output", pred) + with open("debug/onnx_output.txt", "w") as f: + f.write(common.numpy_float_array_to_float_val_str(pred)) + output_dims = common.proto_val_to_dimension_tuple(model.graph.output[0]) + print("Saving the onnx runtime output of dimension " + str(output_dims)) + def preprocess_for_tf(model): - for init_vals in model.graph.initializer: - init_vals.name = 'tf_' + init_vals.name + for init_vals in model.graph.initializer: + init_vals.name = "tf_" + init_vals.name + + for inp in model.graph.input: + inp.name = "tf_" + inp.name - for inp in model.graph.input: - inp.name = 'tf_' + inp.name + for op in model.graph.output: + op.name = "tf_" + op.name - for op in model.graph.output: - op.name = 'tf_' + op.name + for node in model.graph.node: + node.name = "tf_" + node.name + for i in range(len(node.input)): + node.input[i] = "tf_" + node.input[i] + for i in range(len(node.output)): + node.output[i] = "tf_" + node.output[i] + return model - for node in model.graph.node: - node.name = 'tf_' + node.name - for i in range(len(node.input)): - node.input[i] = 'tf_' + node.input[i] - for i in range(len(node.output)): - node.output[i] = 'tf_' + node.output[i] - return model if __name__ == "__main__": - main() + main() diff --git a/Athos/ONNXCompiler/process_onnx.py b/Athos/ONNXCompiler/process_onnx.py index 5b9dfbe4..0e33d8cf 100644 --- a/Athos/ONNXCompiler/process_onnx.py +++ b/Athos/ONNXCompiler/process_onnx.py @@ -1,5 +1,4 @@ - -''' +""" Authors: Shubham Ugare. Copyright: Copyright (c) 2020 Microsoft Research @@ -18,15 +17,15 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os, sys -#Add SeeDot directory to path -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'SeeDot')) +# Add SeeDot directory to path +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "SeeDot")) # For this warning: https://stackoverflow.com/questions/47068709/your-cpu-supports-instructions-that-this-tensorflow-binary-was-not-compiled-to-u -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" import _pickle as pickle import onnx @@ -35,140 +34,211 @@ from ONNXNodesAST import ONNXNodesAST from onnx.helper import make_tensor_value_info from onnx import TensorProto -from AST.PrintAST import PrintAST +from AST.PrintAST import PrintAST from AST.MtdAST import MtdAST import numpy import common import numpy as np + np.set_printoptions(threshold=np.inf) DEBUG = False out_var_prefix = "J" + def main(): - sys.setrecursionlimit(10000) - # First read the ONNX file - if (len(sys.argv) < 2): - print("TF python file unspecified.", file=sys.stderr) - exit(1) - file_name = sys.argv[1] - file_path = 'models/' + file_name - model_name = file_name[:-5] # name without the '.onnx' extension - - # load the model and extract the graph - model = onnx.load(file_path) - graph_def = model.graph - - print(model.graph.value_info) - # Before shape inference (model.graph.value_info) should have shapes of all the variables and constants - model.graph.value_info.append(make_tensor_value_info(model.graph.input[0].name, TensorProto.FLOAT, common.proto_val_to_dimension_tuple(model.graph.input[0]))) - model.graph.value_info.append(make_tensor_value_info(model.graph.output[0].name, TensorProto.FLOAT, common.proto_val_to_dimension_tuple(model.graph.output[0]))) - - print(model.graph.value_info) - - for init_vals in model.graph.initializer: - model.graph.value_info.append(make_tensor_value_info(init_vals.name, TensorProto.FLOAT, tuple(init_vals.dims))) - - if(DEBUG): - print("Shape inference *****************") - print(model.graph.value_info) - - inferred_model = onnx.shape_inference.infer_shapes(model) - - if(DEBUG): - print("Printing shape ******************") - print(inferred_model.graph.value_info) - print("Done ******************") - - # value_info: dictionary of name -> (type, dimension tuple) - value_info = {} - for val in inferred_model.graph.value_info: - value_info[val.name] = (val.type.tensor_type.elem_type, common.proto_val_to_dimension_tuple(val)) - - # Iterate through the ONNX graph nodes and translate them to SeeDot AST nodes - program = None - innermost_let_ast_node = None - node_name_to_out_var_dict = {} - out_var_count = 0 - mtdAST = MtdAST() - - (program, innermost_let_ast_node, out_var_count) = process_input_variables(program, innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info) - - process_onnx_nodes(innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info) - - PrintAST().visit(program) - - common.write_debug_info(node_name_to_out_var_dict) - - with open('debug/'+model_name+'/' +model_name + '.pkl', 'wb') as f: - pickle.dump(program, f) - -def process_input_variables(program, innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info): - node = graph_def.input[0] - curAst = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict, 1) - mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : 'Input', - AST.ASTNode.mtdKeyTFNodeName : node.name} - cur_out_var_ast_node = AST.ID(node.name) - - if program: - assert(type(innermost_let_ast_node) is AST.Let) - newNode = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) - mtdAST.visit(newNode, mtdForCurAST) - # Updating the innermost Let AST node and the expression for previous Let Node - innermost_let_ast_node.expr = newNode - innermost_let_ast_node = newNode - else: - innermost_let_ast_node = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) - mtdAST.visit(innermost_let_ast_node, mtdForCurAST) - innermost_let_ast_node.depth = 0 - program = innermost_let_ast_node - - node_name_to_out_var_dict[node.name] = node.name - - for node in graph_def.initializer: - if(DEBUG): - print("Node information") - print(node) - - curAst = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict) - mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : 'Input', - AST.ASTNode.mtdKeyTFNodeName : node.name} - if (curAst is None): - continue - - cur_out_var_ast_node = AST.ID(node.name) - - if program: - assert(type(innermost_let_ast_node) is AST.Let) - newNode = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) - mtdAST.visit(newNode, mtdForCurAST) - # Updating the innermost Let AST node and the expression for previous Let Node - innermost_let_ast_node.expr = newNode - innermost_let_ast_node = newNode - else: - innermost_let_ast_node = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) - mtdAST.visit(innermost_let_ast_node, mtdForCurAST) - innermost_let_ast_node.depth = 0 - program = innermost_let_ast_node - - node_name_to_out_var_dict[node.name] = node.name - return (program, innermost_let_ast_node, out_var_count) - -def process_onnx_nodes(innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info): - for node in graph_def.node: - if(DEBUG): - print("Node information") - print(node) - - print("Processing " + node.name + "\n") - mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : node.op_type, - AST.ASTNode.mtdKeyTFNodeName : node.name} - - func = getattr(ONNXNodesAST, node.op_type) - (innermost_let_ast_node, out_var_count) = func(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST) - - assert(type(innermost_let_ast_node) is AST.Let) + sys.setrecursionlimit(10000) + # First read the ONNX file + if len(sys.argv) < 2: + print("TF python file unspecified.", file=sys.stderr) + exit(1) + file_name = sys.argv[1] + file_path = "models/" + file_name + model_name = file_name[:-5] # name without the '.onnx' extension + + # load the model and extract the graph + model = onnx.load(file_path) + graph_def = model.graph + + print(model.graph.value_info) + # Before shape inference (model.graph.value_info) should have shapes of all the variables and constants + model.graph.value_info.append( + make_tensor_value_info( + model.graph.input[0].name, + TensorProto.FLOAT, + common.proto_val_to_dimension_tuple(model.graph.input[0]), + ) + ) + model.graph.value_info.append( + make_tensor_value_info( + model.graph.output[0].name, + TensorProto.FLOAT, + common.proto_val_to_dimension_tuple(model.graph.output[0]), + ) + ) + + print(model.graph.value_info) + + for init_vals in model.graph.initializer: + model.graph.value_info.append( + make_tensor_value_info( + init_vals.name, TensorProto.FLOAT, tuple(init_vals.dims) + ) + ) + + if DEBUG: + print("Shape inference *****************") + print(model.graph.value_info) + + inferred_model = onnx.shape_inference.infer_shapes(model) + + if DEBUG: + print("Printing shape ******************") + print(inferred_model.graph.value_info) + print("Done ******************") + + # value_info: dictionary of name -> (type, dimension tuple) + value_info = {} + for val in inferred_model.graph.value_info: + value_info[val.name] = ( + val.type.tensor_type.elem_type, + common.proto_val_to_dimension_tuple(val), + ) + + # Iterate through the ONNX graph nodes and translate them to SeeDot AST nodes + program = None + innermost_let_ast_node = None + node_name_to_out_var_dict = {} + out_var_count = 0 + mtdAST = MtdAST() + + (program, innermost_let_ast_node, out_var_count) = process_input_variables( + program, + innermost_let_ast_node, + node_name_to_out_var_dict, + out_var_count, + mtdAST, + graph_def, + value_info, + ) + + process_onnx_nodes( + innermost_let_ast_node, + node_name_to_out_var_dict, + out_var_count, + mtdAST, + graph_def, + value_info, + ) + + PrintAST().visit(program) + + common.write_debug_info(node_name_to_out_var_dict) + + with open("debug/" + model_name + "/" + model_name + ".pkl", "wb") as f: + pickle.dump(program, f) + + +def process_input_variables( + program, + innermost_let_ast_node, + node_name_to_out_var_dict, + out_var_count, + mtdAST, + graph_def, + value_info, +): + node = graph_def.input[0] + curAst = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict, 1) + mtdForCurAST = { + AST.ASTNode.mtdKeyTFOpName: "Input", + AST.ASTNode.mtdKeyTFNodeName: node.name, + } + cur_out_var_ast_node = AST.ID(node.name) + + if program: + assert type(innermost_let_ast_node) is AST.Let + newNode = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) + mtdAST.visit(newNode, mtdForCurAST) + # Updating the innermost Let AST node and the expression for previous Let Node + innermost_let_ast_node.expr = newNode + innermost_let_ast_node = newNode + else: + innermost_let_ast_node = AST.Let( + cur_out_var_ast_node, curAst, cur_out_var_ast_node + ) + mtdAST.visit(innermost_let_ast_node, mtdForCurAST) + innermost_let_ast_node.depth = 0 + program = innermost_let_ast_node + + node_name_to_out_var_dict[node.name] = node.name + + for node in graph_def.initializer: + if DEBUG: + print("Node information") + print(node) + + curAst = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict) + mtdForCurAST = { + AST.ASTNode.mtdKeyTFOpName: "Input", + AST.ASTNode.mtdKeyTFNodeName: node.name, + } + if curAst is None: + continue + + cur_out_var_ast_node = AST.ID(node.name) + + if program: + assert type(innermost_let_ast_node) is AST.Let + newNode = AST.Let(cur_out_var_ast_node, curAst, cur_out_var_ast_node) + mtdAST.visit(newNode, mtdForCurAST) + # Updating the innermost Let AST node and the expression for previous Let Node + innermost_let_ast_node.expr = newNode + innermost_let_ast_node = newNode + else: + innermost_let_ast_node = AST.Let( + cur_out_var_ast_node, curAst, cur_out_var_ast_node + ) + mtdAST.visit(innermost_let_ast_node, mtdForCurAST) + innermost_let_ast_node.depth = 0 + program = innermost_let_ast_node + + node_name_to_out_var_dict[node.name] = node.name + return (program, innermost_let_ast_node, out_var_count) + + +def process_onnx_nodes( + innermost_let_ast_node, + node_name_to_out_var_dict, + out_var_count, + mtdAST, + graph_def, + value_info, +): + for node in graph_def.node: + if DEBUG: + print("Node information") + print(node) + + print("Processing " + node.name + "\n") + mtdForCurAST = { + AST.ASTNode.mtdKeyTFOpName: node.op_type, + AST.ASTNode.mtdKeyTFNodeName: node.name, + } + + func = getattr(ONNXNodesAST, node.op_type) + (innermost_let_ast_node, out_var_count) = func( + node, + value_info, + node_name_to_out_var_dict, + innermost_let_ast_node, + out_var_count, + mtdAST, + ) + + assert type(innermost_let_ast_node) is AST.Let + if __name__ == "__main__": - main() + main() diff --git a/Athos/ONNXCompiler/test/test.py b/Athos/ONNXCompiler/test/test.py index f2e45991..78375571 100644 --- a/Athos/ONNXCompiler/test/test.py +++ b/Athos/ONNXCompiler/test/test.py @@ -1,4 +1,4 @@ -''' +""" Authors: Shubham Ugare. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import onnx @@ -34,240 +34,268 @@ import time import hashlib -class TestNode(unittest.TestCase): - def _get_rnd_float32(self, low=-1.0, high=1.0, shape=None): - output = np.random.uniform(low, high, shape) - cnt = 1 - for val in shape: cnt*=val - if shape == None: - return np.float32(output) - else: - return output.astype(np.float32).reshape(cnt).tolist() - - def check_result(self, graph, name): - current_milli_time = lambda: str(int(round(time.time() * 1000))) - name = name + "_" + current_milli_time() - model = onnx.helper.make_model(graph, producer_name='onnx-compiler-test') - onnx.save(model, 'models/' + name + '.onnx') - - old_hash = hashlib.md5(open('debug/cpp_output.txt','rb').read()).hexdigest() - - bashCommand = './compile.sh ' + name - process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) - output, error = process.communicate() - - print(output) - print(error) - new_hash = hashlib.md5(open('debug/cpp_output.txt','rb').read()).hexdigest() - - self.assertNotEqual(old_hash, new_hash, 'the compilation did not terminate') - - res_onnx = common.extract_txt_to_numpy_array('debug/onnx_output.txt') - res_cpp = common.extract_txt_to_numpy_array('debug/cpp_output.txt') - - np.save('res_onnx', res_onnx) - np.save('res_cpp', res_cpp) - - self.assertIsNone(error, 'error is non None') - np.testing.assert_almost_equal(res_cpp, res_onnx, decimal=4) - - - def test_conv2d(self): - name = "conv2d" - state_in = helper.make_tensor_value_info('state_in', - TensorProto.FLOAT, [1, 3, 10, 10]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 6, 5, 5]) - node_def = helper.make_node("Conv", ['state_in', 'weight'], ['state_out'], - pads=[1, 1, 1, 1], strides=[2, 2], kernel_shape=[3, 3], group=3) - - weight_shape = [6, 1, 3, 3] - weight_val = self._get_rnd_float32(shape=weight_shape) - - weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) - - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [weight] - ) - self.check_result(graph, name) - - - def test_conv3d(self): - name = "conv3d" - state_in = helper.make_tensor_value_info('state_in',TensorProto.FLOAT, [1, 2, 4, 16, 16]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 2, 4, 16, 16]) - node_def = helper.make_node("Conv", ['state_in', 'weight'], ['state_out'], - pads=[1, 1, 1, 1, 1, 1], strides=[1, 1, 1], kernel_shape=[3, 3, 3]) - - weight_shape = [2, 2, 3, 3, 3] - weight_val = self._get_rnd_float32(shape=weight_shape) - np.save('weight', weight_val) - - weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) - - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [weight] - ) - self.check_result(graph, name) - - def test_conv_transpose(self): - name = "conv_transpose" - state_in = helper.make_tensor_value_info('state_in', - TensorProto.FLOAT, [1, 3, 10, 10]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 5, 19, 19]) - node_def = helper.make_node("ConvTranspose", ['state_in', 'weight'], ['state_out'], - pads=[1, 1, 1, 1], strides=[2, 2], kernel_shape=[3, 3]) - - weight_shape = [3, 5, 3, 3] - weight_val = self._get_rnd_float32(shape=weight_shape) - - weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) - - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [weight] - ) - - self.check_result(graph, name) - - # For this to run onnx_run_tf.py should be used in the compile script - # since onnxruntime does not support convtranspose3d - def test_conv_transpose3d(self): - name = "conv3dTranspose" - state_in = helper.make_tensor_value_info('state_in', - TensorProto.FLOAT, [1, 3, 10, 10, 10]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 5, 19, 19, 19]) - node_def = helper.make_node("ConvTranspose", ['state_in', 'weight', 'bias'], ['state_out'], - # check with pads which are not 1 - pads=[1, 1, 1, 1, 1, 1], strides=[2, 2, 2], kernel_shape=[3, 3, 3]) - - weight_shape = [3, 5, 3, 3, 3] - weight_val = self._get_rnd_float32(shape=weight_shape) - bias_shape = [5] - bias_val = self._get_rnd_float32(shape=bias_shape) - - weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) - bias = helper.make_tensor('bias', TensorProto.FLOAT, bias_shape, bias_val) - - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [weight, bias] - ) - self.check_result(graph, name) - - def test_relu(self): - name = "relu" - state_in = helper.make_tensor_value_info('state_in', - TensorProto.FLOAT, [1, 3, 10, 10]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 3, 10, 10]) - node_def = helper.make_node("Relu", ['state_in'], ['state_out']) - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [] - ) - self.check_result(graph, name) - - def test_pad(self): - name = "pad" - state_in = helper.make_tensor_value_info('state_in', TensorProto.FLOAT, [1, 3, 10, 10]) - pads = helper.make_tensor_value_info('pads', TensorProto.INT64, [8]) - pad_init = numpy_helper.from_array(np.array([0,0,1,1,0,0,1,1], dtype=int), name='pads') - const_val = helper.make_tensor_value_info('const_val', TensorProto.FLOAT, [1]) - const_val_init = numpy_helper.from_array(np.array([0.0], dtype=np.float32), name='const_val') - state_out = helper.make_tensor_value_info('state_out', TensorProto.FLOAT, [1,3,12,12]) - node_def = helper.make_node("Pad", ['state_in', 'pads', 'const_val'], ['state_out'], mode="constant") - graph = helper.make_graph([node_def],name,[state_in, pads, const_val],[state_out],initializer=[pad_init, const_val_init]) - self.check_result(graph, name) - - - def test_relu3d(self): - name = "relu3d" - state_in = helper.make_tensor_value_info('state_in', - TensorProto.FLOAT, [1, 3, 7, 7, 7]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 3, 7, 7, 7]) - node_def = helper.make_node("Relu", ['state_in'], ['state_out']) - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [] - ) - self.check_result(graph, name) - - def test_reducemean(self): - name = "reducemean" - state_in = helper.make_tensor_value_info('state_in', - TensorProto.FLOAT, [1, 1024, 7, 7]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 1024]) - node_def = helper.make_node("ReduceMean", ['state_in'], ['state_out'], axes=[2,3], keepdims=0) - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [] - ) - self.check_result(graph, name) - - def test_batchnormalization(self): - name = "batchnormalization" - state_in = helper.make_tensor_value_info('state_in', - TensorProto.FLOAT, [1, 24, 10, 10]) - state_out = helper.make_tensor_value_info('state_out', - TensorProto.FLOAT, [1, 24, 10, 10]) - node_def = helper.make_node("BatchNormalization", ['state_in', 'weight', 'bias','mean','var'], ['state_out'], - momentum=0.8999999761581421) - - weight_shape = [24] - weight_val = self._get_rnd_float32(shape=weight_shape) - weight = helper.make_tensor('weight', TensorProto.FLOAT, weight_shape, weight_val) - - bias_shape = [24] - bias_val = self._get_rnd_float32(shape=weight_shape) - bias = helper.make_tensor('bias', TensorProto.FLOAT, bias_shape, bias_val) - - mean_shape = [24] - mean_val = self._get_rnd_float32(shape=weight_shape) - mean = helper.make_tensor('mean', TensorProto.FLOAT, mean_shape, mean_val) - - - var_shape = [24] - var_val = self._get_rnd_float32(shape=weight_shape, low=0, high=1) - var = helper.make_tensor('var', TensorProto.FLOAT, var_shape, var_val) - - graph = helper.make_graph( - [node_def], - name, - [state_in], - [state_out], - [weight, bias, mean, var] - ) - self.check_result(graph, name) - -if __name__ == '__main__': - unittest.main() +class TestNode(unittest.TestCase): + def _get_rnd_float32(self, low=-1.0, high=1.0, shape=None): + output = np.random.uniform(low, high, shape) + cnt = 1 + for val in shape: + cnt *= val + if shape == None: + return np.float32(output) + else: + return output.astype(np.float32).reshape(cnt).tolist() + + def check_result(self, graph, name): + current_milli_time = lambda: str(int(round(time.time() * 1000))) + name = name + "_" + current_milli_time() + model = onnx.helper.make_model(graph, producer_name="onnx-compiler-test") + onnx.save(model, "models/" + name + ".onnx") + + old_hash = hashlib.md5(open("debug/cpp_output.txt", "rb").read()).hexdigest() + + bashCommand = "./compile.sh " + name + process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) + output, error = process.communicate() + + print(output) + print(error) + new_hash = hashlib.md5(open("debug/cpp_output.txt", "rb").read()).hexdigest() + + self.assertNotEqual(old_hash, new_hash, "the compilation did not terminate") + + res_onnx = common.extract_txt_to_numpy_array("debug/onnx_output.txt") + res_cpp = common.extract_txt_to_numpy_array("debug/cpp_output.txt") + + np.save("res_onnx", res_onnx) + np.save("res_cpp", res_cpp) + + self.assertIsNone(error, "error is non None") + np.testing.assert_almost_equal(res_cpp, res_onnx, decimal=4) + + def test_conv2d(self): + name = "conv2d" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 3, 10, 10] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 6, 5, 5] + ) + node_def = helper.make_node( + "Conv", + ["state_in", "weight"], + ["state_out"], + pads=[1, 1, 1, 1], + strides=[2, 2], + kernel_shape=[3, 3], + group=3, + ) + + weight_shape = [6, 1, 3, 3] + weight_val = self._get_rnd_float32(shape=weight_shape) + + weight = helper.make_tensor( + "weight", TensorProto.FLOAT, weight_shape, weight_val + ) + + graph = helper.make_graph([node_def], name, [state_in], [state_out], [weight]) + self.check_result(graph, name) + + def test_conv3d(self): + name = "conv3d" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 2, 4, 16, 16] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 2, 4, 16, 16] + ) + node_def = helper.make_node( + "Conv", + ["state_in", "weight"], + ["state_out"], + pads=[1, 1, 1, 1, 1, 1], + strides=[1, 1, 1], + kernel_shape=[3, 3, 3], + ) + + weight_shape = [2, 2, 3, 3, 3] + weight_val = self._get_rnd_float32(shape=weight_shape) + np.save("weight", weight_val) + + weight = helper.make_tensor( + "weight", TensorProto.FLOAT, weight_shape, weight_val + ) + + graph = helper.make_graph([node_def], name, [state_in], [state_out], [weight]) + self.check_result(graph, name) + + def test_conv_transpose(self): + name = "conv_transpose" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 3, 10, 10] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 5, 19, 19] + ) + node_def = helper.make_node( + "ConvTranspose", + ["state_in", "weight"], + ["state_out"], + pads=[1, 1, 1, 1], + strides=[2, 2], + kernel_shape=[3, 3], + ) + + weight_shape = [3, 5, 3, 3] + weight_val = self._get_rnd_float32(shape=weight_shape) + + weight = helper.make_tensor( + "weight", TensorProto.FLOAT, weight_shape, weight_val + ) + + graph = helper.make_graph([node_def], name, [state_in], [state_out], [weight]) + + self.check_result(graph, name) + + # For this to run onnx_run_tf.py should be used in the compile script + # since onnxruntime does not support convtranspose3d + def test_conv_transpose3d(self): + name = "conv3dTranspose" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 3, 10, 10, 10] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 5, 19, 19, 19] + ) + node_def = helper.make_node( + "ConvTranspose", + ["state_in", "weight", "bias"], + ["state_out"], + # check with pads which are not 1 + pads=[1, 1, 1, 1, 1, 1], + strides=[2, 2, 2], + kernel_shape=[3, 3, 3], + ) + + weight_shape = [3, 5, 3, 3, 3] + weight_val = self._get_rnd_float32(shape=weight_shape) + bias_shape = [5] + bias_val = self._get_rnd_float32(shape=bias_shape) + + weight = helper.make_tensor( + "weight", TensorProto.FLOAT, weight_shape, weight_val + ) + bias = helper.make_tensor("bias", TensorProto.FLOAT, bias_shape, bias_val) + + graph = helper.make_graph( + [node_def], name, [state_in], [state_out], [weight, bias] + ) + self.check_result(graph, name) + + def test_relu(self): + name = "relu" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 3, 10, 10] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 3, 10, 10] + ) + node_def = helper.make_node("Relu", ["state_in"], ["state_out"]) + graph = helper.make_graph([node_def], name, [state_in], [state_out], []) + self.check_result(graph, name) + + def test_pad(self): + name = "pad" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 3, 10, 10] + ) + pads = helper.make_tensor_value_info("pads", TensorProto.INT64, [8]) + pad_init = numpy_helper.from_array( + np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=int), name="pads" + ) + const_val = helper.make_tensor_value_info("const_val", TensorProto.FLOAT, [1]) + const_val_init = numpy_helper.from_array( + np.array([0.0], dtype=np.float32), name="const_val" + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 3, 12, 12] + ) + node_def = helper.make_node( + "Pad", ["state_in", "pads", "const_val"], ["state_out"], mode="constant" + ) + graph = helper.make_graph( + [node_def], + name, + [state_in, pads, const_val], + [state_out], + initializer=[pad_init, const_val_init], + ) + self.check_result(graph, name) + + def test_relu3d(self): + name = "relu3d" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 3, 7, 7, 7] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 3, 7, 7, 7] + ) + node_def = helper.make_node("Relu", ["state_in"], ["state_out"]) + graph = helper.make_graph([node_def], name, [state_in], [state_out], []) + self.check_result(graph, name) + + def test_reducemean(self): + name = "reducemean" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 1024, 7, 7] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 1024] + ) + node_def = helper.make_node( + "ReduceMean", ["state_in"], ["state_out"], axes=[2, 3], keepdims=0 + ) + graph = helper.make_graph([node_def], name, [state_in], [state_out], []) + self.check_result(graph, name) + + def test_batchnormalization(self): + name = "batchnormalization" + state_in = helper.make_tensor_value_info( + "state_in", TensorProto.FLOAT, [1, 24, 10, 10] + ) + state_out = helper.make_tensor_value_info( + "state_out", TensorProto.FLOAT, [1, 24, 10, 10] + ) + node_def = helper.make_node( + "BatchNormalization", + ["state_in", "weight", "bias", "mean", "var"], + ["state_out"], + momentum=0.8999999761581421, + ) + + weight_shape = [24] + weight_val = self._get_rnd_float32(shape=weight_shape) + weight = helper.make_tensor( + "weight", TensorProto.FLOAT, weight_shape, weight_val + ) + + bias_shape = [24] + bias_val = self._get_rnd_float32(shape=weight_shape) + bias = helper.make_tensor("bias", TensorProto.FLOAT, bias_shape, bias_val) + + mean_shape = [24] + mean_val = self._get_rnd_float32(shape=weight_shape) + mean = helper.make_tensor("mean", TensorProto.FLOAT, mean_shape, mean_val) + + var_shape = [24] + var_val = self._get_rnd_float32(shape=weight_shape, low=0, high=1) + var = helper.make_tensor("var", TensorProto.FLOAT, var_shape, var_val) + + graph = helper.make_graph( + [node_def], name, [state_in], [state_out], [weight, bias, mean, var] + ) + self.check_result(graph, name) + + +if __name__ == "__main__": + unittest.main() diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index a309a3cf..3c47b0fc 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,106 +20,110 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" from enum import Enum, auto OperatorsSymbolDict = { - "ADD": '+', - "ClearMemPublic": 'clearmempublic', - "ClearMemSecret": 'clearmemsecret', - "CONV": '#', - "CONVTRANSPOSE": "#T", #ConvTranspose - "ElemWiseDiv": './', - "ElemWiseMul":'.*', - "Equal": '==', - "Floor": 'floor', - "Mean": 'mean', - "MUL": '*', - "RELU": 'relu', - "RSQRT": 'rsqrt', - "Shape": 'shape', - "SIGMOID": 'sigmoid', - "SQRT": 'sqrt', - "SUB": '-', - "TANH": 'tanh', + "ADD": "+", + "ClearMemPublic": "clearmempublic", + "ClearMemSecret": "clearmemsecret", + "CONV": "#", + "CONVTRANSPOSE": "#T", # ConvTranspose + "ElemWiseDiv": "./", + "ElemWiseMul": ".*", + "Equal": "==", + "Floor": "floor", + "Mean": "mean", + "MUL": "*", + "RELU": "relu", + "RSQRT": "rsqrt", + "Shape": "shape", + "SIGMOID": "sigmoid", + "SQRT": "sqrt", + "SUB": "-", + "TANH": "tanh", } + class Party(Enum): - SERVER = 0 - CLIENT = 1 + SERVER = 0 + CLIENT = 1 + class Operators(Enum): - ADD = auto() - SUB = auto() - MUL = auto() - CONV = auto() - CONVTRANSPOSE = auto() - RELU = auto() - TANH = auto() - SIGMOID = auto() - SQRT = auto() - RSQRT = auto() - Equal = auto() - ElemWiseMul = auto() - ElemWiseDiv = auto() - Floor = auto() - Shape = auto() - Mean = auto() - ClearMemSecret = auto() - ClearMemPublic = auto() - - def convSymbolToEnumValue(symbolStr): - enumStr = None - for k,v in OperatorsSymbolDict.items(): - if v==symbolStr: - enumStr = k - assert(enumStr is not None) - return Operators[enumStr] - - def findConvTransposePadding(i, i_prime, f, p_total, stride): - # The parameters have the following semantics: - # i = conv input img size - # i_prime = convTranspose input img Size - # f = filter size - # p_total = conv input padding total - # stride = conv input stride - p_total_tr = 2*f - p_total - 2 + ((i + p_total - f)%stride) - stride_tr = 1 - i_prime_tilde = i_prime + (i_prime-1)*(stride-1) - return [p_total_tr, stride_tr, i_prime_tilde] - - def findLeftRightPaddingFromTotalPadding(totalPadding): - leftPadding = totalPadding // 2 - rightPadding = totalPadding - leftPadding - return [leftPadding, rightPadding] - - def findConvOutputImgSize(imgSize, totalPadding, filterSize, stride): - return ((imgSize + totalPadding - filterSize) // stride) + 1 + ADD = auto() + SUB = auto() + MUL = auto() + CONV = auto() + CONVTRANSPOSE = auto() + RELU = auto() + TANH = auto() + SIGMOID = auto() + SQRT = auto() + RSQRT = auto() + Equal = auto() + ElemWiseMul = auto() + ElemWiseDiv = auto() + Floor = auto() + Shape = auto() + Mean = auto() + ClearMemSecret = auto() + ClearMemPublic = auto() + + def convSymbolToEnumValue(symbolStr): + enumStr = None + for k, v in OperatorsSymbolDict.items(): + if v == symbolStr: + enumStr = k + assert enumStr is not None + return Operators[enumStr] + + def findConvTransposePadding(i, i_prime, f, p_total, stride): + # The parameters have the following semantics: + # i = conv input img size + # i_prime = convTranspose input img Size + # f = filter size + # p_total = conv input padding total + # stride = conv input stride + p_total_tr = 2 * f - p_total - 2 + ((i + p_total - f) % stride) + stride_tr = 1 + i_prime_tilde = i_prime + (i_prime - 1) * (stride - 1) + return [p_total_tr, stride_tr, i_prime_tilde] + + def findLeftRightPaddingFromTotalPadding(totalPadding): + leftPadding = totalPadding // 2 + rightPadding = totalPadding - leftPadding + return [leftPadding, rightPadding] + + def findConvOutputImgSize(imgSize, totalPadding, filterSize, stride): + return ((imgSize + totalPadding - filterSize) // stride) + 1 + class PaddingKeysDict: - ConvDim = 2 #2D or 3D convolution, default to 2D - #Also used for convTranpose - FH = "FH" - FW = "FW" - FD = "FD" - zPadHLeft = "zPadHLeft" - zPadHRight = "zPadHRight" - zPadWLeft = "zPadWLeft" - zPadWRight = "zPadWRight" - zPadDLeft = "zPadDLeft" - zPadDRight = "zPadDRight" - strideH = "strideH" - strideW = "strideW" - strideD = "strideD" - inputImgH = "inputImgH" - inputImgW = "inputImgW" - inputImgD = "inputImgD" - outputImgH = "outputImgH" - outputImgW = "outputImgW" - outputImgD = "outputImgD" - paddingUsedStr = "paddingUsedStr" - group = "group" + ConvDim = 2 # 2D or 3D convolution, default to 2D + # Also used for convTranpose + FH = "FH" + FW = "FW" + FD = "FD" + zPadHLeft = "zPadHLeft" + zPadHRight = "zPadHRight" + zPadWLeft = "zPadWLeft" + zPadWRight = "zPadWRight" + zPadDLeft = "zPadDLeft" + zPadDRight = "zPadDRight" + strideH = "strideH" + strideW = "strideW" + strideD = "strideD" + inputImgH = "inputImgH" + inputImgW = "inputImgW" + inputImgD = "inputImgD" + outputImgH = "outputImgH" + outputImgW = "outputImgW" + outputImgD = "outputImgD" + paddingUsedStr = "paddingUsedStr" + group = "group" + # If this is marked true, each astNode checks the types of its inputs to confirm it satisfies the assumption # Turn this off to get speedup in compilation @@ -127,277 +131,333 @@ class PaddingKeysDict: # Represents expression. All other nodes are specific types of expr. class ASTNode: - mtdKeyTFOpName = "TFOpName" - mtdKeyTFNodeName = "TFNodeName" - def __init__(self): - self.gamma = {} - self.metadata = {} - self.decls = {} - self.depth = 0 - self.optidict = {} + mtdKeyTFOpName = "TFOpName" + mtdKeyTFNodeName = "TFNodeName" + + def __init__(self): + self.gamma = {} + self.metadata = {} + self.decls = {} + self.depth = 0 + self.optidict = {} + class Int(ASTNode): - def __init__(self, value: int, bitLen=None, isSecret=True, isScaled=False): - if assertInputTypes: - assert isinstance(value, int) - if bitLen: - assert isinstance(bitLen, int) - assert ((bitLen==32) or (bitLen==64)) - assert isinstance(isSecret, bool) - assert isinstance(isScaled, bool) - super().__init__() - self.value = value - self.bitLen = bitLen - self.isSecret = isSecret - self.isScaled = isScaled + def __init__(self, value: int, bitLen=None, isSecret=True, isScaled=False): + if assertInputTypes: + assert isinstance(value, int) + if bitLen: + assert isinstance(bitLen, int) + assert (bitLen == 32) or (bitLen == 64) + assert isinstance(isSecret, bool) + assert isinstance(isScaled, bool) + super().__init__() + self.value = value + self.bitLen = bitLen + self.isSecret = isSecret + self.isScaled = isScaled + class Float(ASTNode): - def __init__(self, value: float, isSecret=True): - if assertInputTypes: - assert isinstance(value, float) - assert isinstance(isSecret, bool) - super().__init__() - self.value = value - self.isSecret = isSecret + def __init__(self, value: float, isSecret=True): + if assertInputTypes: + assert isinstance(value, float) + assert isinstance(isSecret, bool) + super().__init__() + self.value = value + self.isSecret = isSecret + class ID(ASTNode): - def __init__(self, name: str): - if assertInputTypes: - assert isinstance(name, str) - super().__init__() - self.name = name + def __init__(self, name: str): + if assertInputTypes: + assert isinstance(name, str) + super().__init__() + self.name = name + # shape : list of int, valueList : list of int/float AST Nodes class Decl(ASTNode): - def __init__(self, shape: list, dataType: str, valueList: list, isSecret=True, isScaled=False): - if assertInputTypes: - for elem in shape: assert isinstance(elem, int) - if dataType: - assert isinstance(dataType, str) - if valueList: - for elem in valueList: assert isinstance(elem ,(Int,Float)) - assert(isinstance(isSecret, bool)) - assert(isinstance(isScaled, bool)) - super().__init__() - self.shape = shape - self.dataType = dataType - self.valueList = valueList - self.isSecret = isSecret - self.isScaled = isScaled + def __init__( + self, shape: list, dataType: str, valueList: list, isSecret=True, isScaled=False + ): + if assertInputTypes: + for elem in shape: + assert isinstance(elem, int) + if dataType: + assert isinstance(dataType, str) + if valueList: + for elem in valueList: + assert isinstance(elem, (Int, Float)) + assert isinstance(isSecret, bool) + assert isinstance(isScaled, bool) + super().__init__() + self.shape = shape + self.dataType = dataType + self.valueList = valueList + self.isSecret = isSecret + self.isScaled = isScaled + # expr : ASTNode, perm : list of ints class Transpose(ASTNode): - def __init__(self, expr: ASTNode, perm: list = None): - if assertInputTypes: - assert isinstance(expr, ASTNode) - if perm: - for elem in perm: assert isinstance(elem, int) - super().__init__() - self.expr = expr - self.perm = perm + def __init__(self, expr: ASTNode, perm: list = None): + if assertInputTypes: + assert isinstance(expr, ASTNode) + if perm: + for elem in perm: + assert isinstance(elem, int) + super().__init__() + self.expr = expr + self.perm = perm + # expr : ASTNode, perm : list of ints class Slice(ASTNode): - def __init__(self, expr: ASTNode, subscriptRanges: list = None): - if assertInputTypes: - assert isinstance(expr, ID) - if subscriptRanges: - for elem in subscriptRanges: - assert isinstance(elem[0], int) - assert isinstance(elem[1], int) - super().__init__() - self.expr = expr - self.subscriptRanges = subscriptRanges + def __init__(self, expr: ASTNode, subscriptRanges: list = None): + if assertInputTypes: + assert isinstance(expr, ID) + if subscriptRanges: + for elem in subscriptRanges: + assert isinstance(elem[0], int) + assert isinstance(elem[1], int) + super().__init__() + self.expr = expr + self.subscriptRanges = subscriptRanges + # expr : ASTNode, shape : list of int, order : int : optional class Reshape(ASTNode): - def __init__(self, expr: ASTNode, shape: list, order: list): - if assertInputTypes: - assert isinstance(expr, ASTNode) - for elem in shape: assert isinstance(elem, int) - if order: - for elem in order: assert isinstance(elem, int) - super().__init__() - self.expr = expr - self.shape = shape - self.order = order + def __init__(self, expr: ASTNode, shape: list, order: list): + if assertInputTypes: + assert isinstance(expr, ASTNode) + for elem in shape: + assert isinstance(elem, int) + if order: + for elem in order: + assert isinstance(elem, int) + super().__init__() + self.expr = expr + self.shape = shape + self.order = order + # expr : ASTNode # options : Other options required by maxpool -# Order: [FH, FW, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW] +# Order: [FH, FW, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW] class Pool(ASTNode): - class PoolType: - MaxPool = "MaxPool" - AvgPool = "AvgPool" - - def __init__(self, poolType:str, expr:ASTNode, options:dict): - if assertInputTypes: - assert (poolType==Pool.PoolType.MaxPool or poolType==Pool.PoolType.AvgPool) - assert isinstance(expr, ASTNode) - assert isinstance(options, dict) - assert (PaddingKeysDict.FH in options) - assert (PaddingKeysDict.FW in options) - assert (PaddingKeysDict.zPadHLeft in options) - assert (PaddingKeysDict.zPadHRight in options) - assert (PaddingKeysDict.zPadWLeft in options) - assert (PaddingKeysDict.zPadWRight in options) - assert (PaddingKeysDict.strideH in options) - assert (PaddingKeysDict.strideW in options) - - super().__init__() - self.poolType = poolType - self.expr = expr - self.options = options + class PoolType: + MaxPool = "MaxPool" + AvgPool = "AvgPool" + + def __init__(self, poolType: str, expr: ASTNode, options: dict): + if assertInputTypes: + assert ( + poolType == Pool.PoolType.MaxPool or poolType == Pool.PoolType.AvgPool + ) + assert isinstance(expr, ASTNode) + assert isinstance(options, dict) + assert PaddingKeysDict.FH in options + assert PaddingKeysDict.FW in options + assert PaddingKeysDict.zPadHLeft in options + assert PaddingKeysDict.zPadHRight in options + assert PaddingKeysDict.zPadWLeft in options + assert PaddingKeysDict.zPadWRight in options + assert PaddingKeysDict.strideH in options + assert PaddingKeysDict.strideW in options + + super().__init__() + self.poolType = poolType + self.expr = expr + self.options = options + class UOp(ASTNode): - def __init__(self, op: Operators, expr:ASTNode): - if assertInputTypes: - assert isinstance(op, Operators) - assert isinstance(expr, ASTNode) - super().__init__() - self.op = op - self.expr = expr + def __init__(self, op: Operators, expr: ASTNode): + if assertInputTypes: + assert isinstance(op, Operators) + assert isinstance(expr, ASTNode) + super().__init__() + self.op = op + self.expr = expr + class BOp(ASTNode): - # Options is used to convey extra info if the operator needs so - # For example, it will be useful for convolution to convey strides etc. - - # IMPORTANT NOTE: The options parameter coming for ConvTranspose is for the conv of which it is an inverse - - def __init__(self, expr1: ASTNode, op: Operators, expr2: ASTNode, options=None): - if assertInputTypes: - assert isinstance(expr1, ASTNode) - assert isinstance(op, Operators) - assert isinstance(expr2, ASTNode) - if options: assert isinstance(options, dict) - if op == Operators.CONV or op == Operators.CONVTRANSPOSE: - assert (PaddingKeysDict.FH in options) - assert (PaddingKeysDict.FW in options) - assert (PaddingKeysDict.zPadHLeft in options) - assert (PaddingKeysDict.zPadHRight in options) - assert (PaddingKeysDict.zPadWLeft in options) - assert (PaddingKeysDict.zPadWRight in options) - assert (PaddingKeysDict.strideH in options) - assert (PaddingKeysDict.strideW in options) - if PaddingKeysDict.ConvDim in options: - assert(options[PaddingKeysDict.ConvDim]==2 or options[PaddingKeysDict.ConvDim]==3) - if options[PaddingKeysDict.ConvDim]==3: - #3D conv - assert over the depth dimension - assert (PaddingKeysDict.FD in options) - assert (PaddingKeysDict.zPadDLeft in options) - assert (PaddingKeysDict.zPadDRight in options) - assert (PaddingKeysDict.strideD in options) - if op == Operators.CONVTRANSPOSE: - # In addition if this op is convTranspose, then - # the output size should also be specified - assert(PaddingKeysDict.outputImgH in options) - assert(PaddingKeysDict.outputImgW in options) - if (PaddingKeysDict.ConvDim in options) and (options[PaddingKeysDict.ConvDim]==3): - assert(PaddingKeysDict.outputImgD in options) - super().__init__() - self.expr1 = expr1 - self.op = op - self.expr2 = expr2 - self.options = options + # Options is used to convey extra info if the operator needs so + # For example, it will be useful for convolution to convey strides etc. + + # IMPORTANT NOTE: The options parameter coming for ConvTranspose is for the conv of which it is an inverse + + def __init__(self, expr1: ASTNode, op: Operators, expr2: ASTNode, options=None): + if assertInputTypes: + assert isinstance(expr1, ASTNode) + assert isinstance(op, Operators) + assert isinstance(expr2, ASTNode) + if options: + assert isinstance(options, dict) + if op == Operators.CONV or op == Operators.CONVTRANSPOSE: + assert PaddingKeysDict.FH in options + assert PaddingKeysDict.FW in options + assert PaddingKeysDict.zPadHLeft in options + assert PaddingKeysDict.zPadHRight in options + assert PaddingKeysDict.zPadWLeft in options + assert PaddingKeysDict.zPadWRight in options + assert PaddingKeysDict.strideH in options + assert PaddingKeysDict.strideW in options + if PaddingKeysDict.ConvDim in options: + assert ( + options[PaddingKeysDict.ConvDim] == 2 + or options[PaddingKeysDict.ConvDim] == 3 + ) + if options[PaddingKeysDict.ConvDim] == 3: + # 3D conv - assert over the depth dimension + assert PaddingKeysDict.FD in options + assert PaddingKeysDict.zPadDLeft in options + assert PaddingKeysDict.zPadDRight in options + assert PaddingKeysDict.strideD in options + if op == Operators.CONVTRANSPOSE: + # In addition if this op is convTranspose, then + # the output size should also be specified + assert PaddingKeysDict.outputImgH in options + assert PaddingKeysDict.outputImgW in options + if (PaddingKeysDict.ConvDim in options) and ( + options[PaddingKeysDict.ConvDim] == 3 + ): + assert PaddingKeysDict.outputImgD in options + super().__init__() + self.expr1 = expr1 + self.op = op + self.expr2 = expr2 + self.options = options + class Func(ASTNode): - def __init__(self, op: Operators, expr: ASTNode): - if assertInputTypes: - assert isinstance(op, Operators) - assert isinstance(expr, ASTNode) - super().__init__() - self.op = op - self.expr = expr + def __init__(self, op: Operators, expr: ASTNode): + if assertInputTypes: + assert isinstance(op, Operators) + assert isinstance(expr, ASTNode) + super().__init__() + self.op = op + self.expr = expr + class Let(ASTNode): - def __init__(self, name: ID, decl: ASTNode, expr: ASTNode): - if assertInputTypes: - assert isinstance(name, ID) - assert isinstance(decl, ASTNode) - assert isinstance(expr, ASTNode) - super().__init__() - self.name = name - self.decl = decl - self.expr = expr + def __init__(self, name: ID, decl: ASTNode, expr: ASTNode): + if assertInputTypes: + assert isinstance(name, ID) + assert isinstance(decl, ASTNode) + assert isinstance(expr, ASTNode) + super().__init__() + self.name = name + self.decl = decl + self.expr = expr + # Assumption is that the output of this is always a tensor # outputShape : list of int, funcName : str, argsList : list of ASTNodes # isSecret : whether the output of this node is public or secret # outputDiffInpDims = 0 => output only different input dims -# = 1 => always output input dims -# = 2 => never output input dims -# : NOTE this doesn't apply for function names +# = 1 => always output input dims +# = 2 => never output input dims +# : NOTE this doesn't apply for function names class UninterpFuncCall(ASTNode): - def __init__(self, outputShape: list, funcName: str, argsList: list, isSecret=True, outputDiffInpDims=0): - if assertInputTypes: - for elem in outputShape: assert isinstance(elem, int) - assert isinstance(funcName, str) - for arg in argsList: assert isinstance(arg, ASTNode) - assert isinstance(isSecret, bool) - assert isinstance(outputDiffInpDims, int) - super().__init__() - self.outputShape = outputShape - self.funcName = funcName - self.argsList = argsList - self.isSecret = isSecret - self.outputDiffInpDims = outputDiffInpDims + def __init__( + self, + outputShape: list, + funcName: str, + argsList: list, + isSecret=True, + outputDiffInpDims=0, + ): + if assertInputTypes: + for elem in outputShape: + assert isinstance(elem, int) + assert isinstance(funcName, str) + for arg in argsList: + assert isinstance(arg, ASTNode) + assert isinstance(isSecret, bool) + assert isinstance(outputDiffInpDims, int) + super().__init__() + self.outputShape = outputShape + self.funcName = funcName + self.argsList = argsList + self.isSecret = isSecret + self.outputDiffInpDims = outputDiffInpDims + class ArgMax(ASTNode): - def __init__(self, outputShape: list, expr: ID, dim: ASTNode, inShape: list): - if assertInputTypes: - for elem in outputShape: assert isinstance(elem, int) - assert isinstance(expr, ID) - assert isinstance(dim, ASTNode) - for elem in inShape: assert isinstance(elem, int) - super().__init__() - self.outputShape = outputShape - self.expr = expr - self.dim = dim - self.inShape = inShape + def __init__(self, outputShape: list, expr: ID, dim: ASTNode, inShape: list): + if assertInputTypes: + for elem in outputShape: + assert isinstance(elem, int) + assert isinstance(expr, ID) + assert isinstance(dim, ASTNode) + for elem in inShape: + assert isinstance(elem, int) + super().__init__() + self.outputShape = outputShape + self.expr = expr + self.dim = dim + self.inShape = inShape + class Reduce(ASTNode): - def __init__(self, expr:ID, keepdims:bool, outShape:list, op: Operators, reductionAxesList: list): - # keepdims is unused for now - if assertInputTypes: - assert isinstance(expr, ID) - assert isinstance(keepdims, bool) - assert isinstance(outShape, list) - for elem in outShape: assert isinstance(elem, int) - assert isinstance(op, Operators) - super().__init__() - self.expr = expr - self.keepdims = keepdims - self.outShape = outShape - self.op = op - self.reductionAxesList = reductionAxesList + def __init__( + self, + expr: ID, + keepdims: bool, + outShape: list, + op: Operators, + reductionAxesList: list, + ): + # keepdims is unused for now + if assertInputTypes: + assert isinstance(expr, ID) + assert isinstance(keepdims, bool) + assert isinstance(outShape, list) + for elem in outShape: + assert isinstance(elem, int) + assert isinstance(op, Operators) + super().__init__() + self.expr = expr + self.keepdims = keepdims + self.outShape = outShape + self.op = op + self.reductionAxesList = reductionAxesList + # shape : list of int, dataType : ID -# NOTE: Though datatype is being passed to this function, the output code eventually only has -# int in the apt bitlen for which the whole compilation is done +# NOTE: Though datatype is being passed to this function, the output code eventually only has +# int in the apt bitlen for which the whole compilation is done # Also, take note of the last parameter - "inputByParty". This can be used to set the party which -# which will do the input for this variable. Defaults to SERVER. +# which will do the input for this variable. Defaults to SERVER. class Input(ASTNode): - def __init__(self, shape:list, dataType:str, isSecret=True, inputByParty=Party.SERVER): - if assertInputTypes: - for elem in shape: assert isinstance(elem, int) - assert isinstance(dataType, str) - assert isinstance(inputByParty, Party) - assert(inputByParty==Party.CLIENT or inputByParty==Party.SERVER) #Right now EzPC supports input by two parties. - super().__init__() - self.shape = shape - self.dataType = dataType - self.isSecret = isSecret - self.inputByParty = inputByParty + def __init__( + self, shape: list, dataType: str, isSecret=True, inputByParty=Party.SERVER + ): + if assertInputTypes: + for elem in shape: + assert isinstance(elem, int) + assert isinstance(dataType, str) + assert isinstance(inputByParty, Party) + assert ( + inputByParty == Party.CLIENT or inputByParty == Party.SERVER + ) # Right now EzPC supports input by two parties. + super().__init__() + self.shape = shape + self.dataType = dataType + self.isSecret = isSecret + self.inputByParty = inputByParty + # Since some optimizations are possible around batchnorm, keep this as an interpreted node class FusedBatchNorm(ASTNode): - def __init__(self, expr:ID, multExpr:ID, addExpr:ID): - if assertInputTypes: - assert isinstance(expr, ID) - assert isinstance(multExpr, ID) - assert isinstance(addExpr, ID) - super().__init__() - self.expr = expr - self.multExpr = multExpr - self.addExpr = addExpr - + def __init__(self, expr: ID, multExpr: ID, addExpr: ID): + if assertInputTypes: + assert isinstance(expr, ID) + assert isinstance(multExpr, ID) + assert isinstance(addExpr, ID) + super().__init__() + self.expr = expr + self.multExpr = multExpr + self.addExpr = addExpr diff --git a/Athos/SeeDot/AST/ASTVisitor.py b/Athos/SeeDot/AST/ASTVisitor.py index 03f04ad3..81835f84 100644 --- a/Athos/SeeDot/AST/ASTVisitor.py +++ b/Athos/SeeDot/AST/ASTVisitor.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,109 +20,110 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import AST.AST as AST + class ASTVisitor: - def visitInt(self, node:AST.Int, args=None): - pass - - def visitFloat(self, node:AST.Float, args=None): - pass - - def visitId(self, node:AST.ID, args=None): - pass - - def visitDecl(self, node:AST.Decl, args=None): - if node.valueList: - for elem in node.valueList: - self.visit(elem, args) - - def visitTranspose(self, node:AST.Transpose, args=None): - self.visit(node.expr, args) - - def visitSlice(self, node:AST.Slice, args=None): - self.visit(node.expr, args) - - def visitReshape(self, node:AST.Reshape, args=None): - self.visit(node.expr, args) - - def visitPool(self, node:AST.Pool, args=None): - self.visit(node.expr, args) - - def visitUOp(self, node:AST.UOp, args=None): - self.visit(node.expr, args) - - def visitBOp(self, node:AST.BOp, args=None): - self.visit(node.expr1, args) - self.visit(node.expr2, args) - - def visitFunc(self, node:AST.Func, args=None): - self.visit(node.expr, args) - - def visitLet(self, node:AST.Let, args=None): - self.visit(node.name, args) - self.visit(node.decl, args) - self.visit(node.expr, args) - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args=None): - for elem in node.argsList: - self.visit(elem, args) - - def visitArgMax(self, node:AST.ArgMax, args=None): - self.visit(node.expr, args) - self.visit(node.dim, args) - - def visitReduce(self, node:AST.Reduce, args=None): - self.visit(node.expr, args) - - def visitInput(self, node:AST.Input, args=None): - pass - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): - self.visit(node.expr, args) - self.visit(node.multExpr, args) - self.visit(node.addExpr, args) - - def visit(self, node, args=None): - if node is None: - return - if isinstance(node, AST.Int): - return self.visitInt(node, args) - elif isinstance(node, AST.Float): - return self.visitFloat(node, args) - elif isinstance(node, AST.ID): - return self.visitId(node, args) - elif isinstance(node, AST.Decl): - return self.visitDecl(node, args) - elif isinstance(node, AST.Transpose): - return self.visitTranspose(node, args) - elif isinstance(node, AST.Slice): - return self.visitSlice(node, args) - elif isinstance(node, AST.Reshape): - return self.visitReshape(node, args) - elif isinstance(node, AST.Pool): - return self.visitPool(node, args) - elif isinstance(node, AST.UOp): - return self.visitUOp(node, args) - elif isinstance(node, AST.BOp): - return self.visitBOp(node, args) - elif isinstance(node, AST.Func): - return self.visitFunc(node, args) - elif isinstance(node, AST.Let): - return self.visitLet(node, args) - elif isinstance(node, AST.UninterpFuncCall): - return self.visitUninterpFuncCall(node, args) - elif isinstance(node, AST.ArgMax): - return self.visitArgMax(node, args) - elif isinstance(node, AST.Reduce): - return self.visitReduce(node, args) - elif isinstance(node, AST.Input): - return self.visitInput(node, args) - elif isinstance(node, AST.FusedBatchNorm): - return self.visitFusedBatchNorm(node, args) - elif node: - raise Exception('Node instance not matched.') - else: - pass + def visitInt(self, node: AST.Int, args=None): + pass + + def visitFloat(self, node: AST.Float, args=None): + pass + + def visitId(self, node: AST.ID, args=None): + pass + + def visitDecl(self, node: AST.Decl, args=None): + if node.valueList: + for elem in node.valueList: + self.visit(elem, args) + + def visitTranspose(self, node: AST.Transpose, args=None): + self.visit(node.expr, args) + + def visitSlice(self, node: AST.Slice, args=None): + self.visit(node.expr, args) + + def visitReshape(self, node: AST.Reshape, args=None): + self.visit(node.expr, args) + + def visitPool(self, node: AST.Pool, args=None): + self.visit(node.expr, args) + + def visitUOp(self, node: AST.UOp, args=None): + self.visit(node.expr, args) + + def visitBOp(self, node: AST.BOp, args=None): + self.visit(node.expr1, args) + self.visit(node.expr2, args) + + def visitFunc(self, node: AST.Func, args=None): + self.visit(node.expr, args) + + def visitLet(self, node: AST.Let, args=None): + self.visit(node.name, args) + self.visit(node.decl, args) + self.visit(node.expr, args) + + def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, args=None): + for elem in node.argsList: + self.visit(elem, args) + + def visitArgMax(self, node: AST.ArgMax, args=None): + self.visit(node.expr, args) + self.visit(node.dim, args) + + def visitReduce(self, node: AST.Reduce, args=None): + self.visit(node.expr, args) + + def visitInput(self, node: AST.Input, args=None): + pass + + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None): + self.visit(node.expr, args) + self.visit(node.multExpr, args) + self.visit(node.addExpr, args) + + def visit(self, node, args=None): + if node is None: + return + if isinstance(node, AST.Int): + return self.visitInt(node, args) + elif isinstance(node, AST.Float): + return self.visitFloat(node, args) + elif isinstance(node, AST.ID): + return self.visitId(node, args) + elif isinstance(node, AST.Decl): + return self.visitDecl(node, args) + elif isinstance(node, AST.Transpose): + return self.visitTranspose(node, args) + elif isinstance(node, AST.Slice): + return self.visitSlice(node, args) + elif isinstance(node, AST.Reshape): + return self.visitReshape(node, args) + elif isinstance(node, AST.Pool): + return self.visitPool(node, args) + elif isinstance(node, AST.UOp): + return self.visitUOp(node, args) + elif isinstance(node, AST.BOp): + return self.visitBOp(node, args) + elif isinstance(node, AST.Func): + return self.visitFunc(node, args) + elif isinstance(node, AST.Let): + return self.visitLet(node, args) + elif isinstance(node, AST.UninterpFuncCall): + return self.visitUninterpFuncCall(node, args) + elif isinstance(node, AST.ArgMax): + return self.visitArgMax(node, args) + elif isinstance(node, AST.Reduce): + return self.visitReduce(node, args) + elif isinstance(node, AST.Input): + return self.visitInput(node, args) + elif isinstance(node, AST.FusedBatchNorm): + return self.visitFusedBatchNorm(node, args) + elif node: + raise Exception("Node instance not matched.") + else: + pass diff --git a/Athos/SeeDot/AST/IRBuilderAST.py b/Athos/SeeDot/AST/IRBuilderAST.py index 1d27e16b..10b6aaf6 100644 --- a/Athos/SeeDot/AST/IRBuilderAST.py +++ b/Athos/SeeDot/AST/IRBuilderAST.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,16 +20,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import AST.AST as AST from AST.ASTVisitor import ASTVisitor + class IRBuilderAST(ASTVisitor): - typeInfo = {} - def visit(self, node, args=None): - ret = super().visit( node, args) - if type(ret) is tuple: - if ret[1].idf not in self.typeInfo: - self.typeInfo[ret[1].idf] = node.type - return ret \ No newline at end of file + typeInfo = {} + + def visit(self, node, args=None): + ret = super().visit(node, args) + if type(ret) is tuple: + if ret[1].idf not in self.typeInfo: + self.typeInfo[ret[1].idf] = node.type + return ret diff --git a/Athos/SeeDot/AST/MtdAST.py b/Athos/SeeDot/AST/MtdAST.py index ef9a4102..27c9d938 100644 --- a/Athos/SeeDot/AST/MtdAST.py +++ b/Athos/SeeDot/AST/MtdAST.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,77 +20,78 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import AST.AST as AST from AST.ASTVisitor import ASTVisitor + class MtdAST(ASTVisitor): - def visitInt(self, node:AST.Int, mtd:dict): - node.metadata.update(mtd) - - def visitFloat(self, node:AST.Float, mtd:dict): - node.metadata.update(mtd) - - def visitId(self, node:AST.ID, mtd:dict): - node.metadata.update(mtd) - - def visitDecl(self, node:AST.Decl, mtd:dict): - node.metadata.update(mtd) - - def visitTranspose(self, node:AST.Transpose, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitSlice(self, node:AST.Slice, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitReshape(self, node:AST.Reshape, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitPool(self, node:AST.Pool, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitUOp(self, node:AST.UOp, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitBOp(self, node:AST.BOp, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr1, mtd) - self.visit(node.expr2, mtd) - - def visitFunc(self, node:AST.Func, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitLet(self, node:AST.Let, mtd:dict): - node.metadata.update(mtd) - self.visit(node.name, mtd) - self.visit(node.decl, mtd) - self.visit(node.expr, mtd) - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, mtd:dict): - node.metadata.update(mtd) - for curArg in node.argsList: - self.visit(curArg, mtd) - - def visitArgMax(self, node:AST.ArgMax, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitReduce(self, node:AST.Reduce, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - - def visitInput(self, node:AST.Input, mtd:dict): - node.metadata.update(mtd) - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, mtd:dict): - node.metadata.update(mtd) - self.visit(node.expr, mtd) - self.visit(node.multExpr, mtd) - self.visit(node.addExpr, mtd) + def visitInt(self, node: AST.Int, mtd: dict): + node.metadata.update(mtd) + + def visitFloat(self, node: AST.Float, mtd: dict): + node.metadata.update(mtd) + + def visitId(self, node: AST.ID, mtd: dict): + node.metadata.update(mtd) + + def visitDecl(self, node: AST.Decl, mtd: dict): + node.metadata.update(mtd) + + def visitTranspose(self, node: AST.Transpose, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitSlice(self, node: AST.Slice, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitReshape(self, node: AST.Reshape, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitPool(self, node: AST.Pool, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitUOp(self, node: AST.UOp, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitBOp(self, node: AST.BOp, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr1, mtd) + self.visit(node.expr2, mtd) + + def visitFunc(self, node: AST.Func, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitLet(self, node: AST.Let, mtd: dict): + node.metadata.update(mtd) + self.visit(node.name, mtd) + self.visit(node.decl, mtd) + self.visit(node.expr, mtd) + + def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, mtd: dict): + node.metadata.update(mtd) + for curArg in node.argsList: + self.visit(curArg, mtd) + + def visitArgMax(self, node: AST.ArgMax, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitReduce(self, node: AST.Reduce, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + + def visitInput(self, node: AST.Input, mtd: dict): + node.metadata.update(mtd) + + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, mtd: dict): + node.metadata.update(mtd) + self.visit(node.expr, mtd) + self.visit(node.multExpr, mtd) + self.visit(node.addExpr, mtd) diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py index 1ef915da..ccc4b7c0 100644 --- a/Athos/SeeDot/AST/PrintAST.py +++ b/Athos/SeeDot/AST/PrintAST.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import AST.AST as AST from AST.ASTVisitor import ASTVisitor @@ -28,105 +28,129 @@ indent = "" -class PrintAST(ASTVisitor): - #TODO : fix printing of AST - def visitInt(self, node:AST.Int, args=None): - print(indent * node.depth, node.value, end=' ') - - def visitFloat(self, node:AST.Float, args=None): - print(indent * node.depth, node.value, end=' ') - - def visitId(self, node:AST.ID, args=None): - print(indent * node.depth, node.name, end=' ') - - def visitDecl(self, node:AST.Decl, args=None): - if node.valueList: - print(indent * node.depth, node.shape, list(map(lambda x: x.value, node.valueList)), end=' ') - else: - print(indent * node.depth, node.shape, end=' ') - - def visitTranspose(self, node:AST.Transpose, args=None): - node.expr.depth = node.depth + 1 - print(indent * node.depth, end=' ') - self.visit(node.expr) - print("^Transpose", end=' ') - - def visitSlice(self, node:AST.Transpose, args=None): - node.expr.depth = node.depth + 1 - print(indent * node.depth, end=' ') - self.visit(node.expr) - print("extract slice", end=' ') - - def visitReshape(self, node:AST.Reshape, args=None): - node.expr.depth = node.depth + 1 - print(indent * node.depth, "reshape", end=' ') - self.visit(node.expr) - if (node.order): - print(node.shape, "order", node.order, end=' ') - else: - print(node.shape, end=' ') - - def visitPool(self, node:AST.Pool, args=None): - node.expr.depth = node.depth + 1 - print(indent * node.depth, node.poolType, end=' ') - self.visit(node.expr) - - def visitUOp(self, node:AST.UOp, args=None): - node.expr.depth = node.depth + 1 - print(indent * node.depth, AST.OperatorsSymbolDict[node.op.name], end=' ') - self.visit(node.expr) - - def visitBOp(self, node:AST.BOp, args=None): - node.expr1.depth = node.expr2.depth = node.depth + 1 - print(indent * node.depth, end=' ') - self.visit(node.expr1) - print(AST.OperatorsSymbolDict[node.op.name], end=' ') - self.visit(node.expr2) - - def visitFunc(self, node:AST.Func, args=None): - print(indent * node.depth, AST.OperatorsSymbolDict[node.op.name], end=' ') - node.expr.depth = node.depth + 1 - self.visit(node.expr) - - def visitLet(self, node:AST.Let, args=None): - if (node.decl is not None): - node.decl.depth = node.depth + 1 - if (node.expr is not None): - node.expr.depth = node.depth + 1 - print(indent * node.depth, "(", end=' ') - print("let", end=' ') - if(hasattr(node.name, 'type') and hasattr(node.name.type, 'taint')): - print("<", node.decl.type.taint.name, ">",end=' ') - self.visit(node.name) - print("=", end=' ') - self.visit(node.decl) - print("{", node.metadata[AST.ASTNode.mtdKeyTFOpName], node.metadata[AST.ASTNode.mtdKeyTFNodeName], "} in ", end='\n') - self.visit(node.expr) - print(')',end='') - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args=None): - print(indent * node.depth, "UninterpFuncCall", node.funcName, end=' ') - for x in node.argsList: - self.visit(x) - - def visitArgMax(self, node:AST.ArgMax, args=None): - print(indent * node.depth, "ArgMax", end=' ') - self.visit(node.expr) - self.visit(node.dim) - - def visitReduce(self, node:AST.Reduce, args=None): - print(indent * node.depth, "reduce", AST.OperatorsSymbolDict[node.op.name], end=' ') - self.visit(node.expr) - - def visitInput(self, node:AST.Input, args=None): - print(indent * node.depth, "input( ", node.shape, node.dataType, " <", node.inputByParty.name, "> ", end='') - print(" )", end='') - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): - node.expr.depth = node.multExpr.depth = node.addExpr.depth = node.depth + 1 - print(indent * node.depth, "FusedBatchNorm", end=' ') - self.visit(node.expr) - self.visit(node.multExpr) - self.visit(node.addExpr) - +class PrintAST(ASTVisitor): + # TODO : fix printing of AST + def visitInt(self, node: AST.Int, args=None): + print(indent * node.depth, node.value, end=" ") + + def visitFloat(self, node: AST.Float, args=None): + print(indent * node.depth, node.value, end=" ") + + def visitId(self, node: AST.ID, args=None): + print(indent * node.depth, node.name, end=" ") + + def visitDecl(self, node: AST.Decl, args=None): + if node.valueList: + print( + indent * node.depth, + node.shape, + list(map(lambda x: x.value, node.valueList)), + end=" ", + ) + else: + print(indent * node.depth, node.shape, end=" ") + + def visitTranspose(self, node: AST.Transpose, args=None): + node.expr.depth = node.depth + 1 + print(indent * node.depth, end=" ") + self.visit(node.expr) + print("^Transpose", end=" ") + + def visitSlice(self, node: AST.Transpose, args=None): + node.expr.depth = node.depth + 1 + print(indent * node.depth, end=" ") + self.visit(node.expr) + print("extract slice", end=" ") + + def visitReshape(self, node: AST.Reshape, args=None): + node.expr.depth = node.depth + 1 + print(indent * node.depth, "reshape", end=" ") + self.visit(node.expr) + if node.order: + print(node.shape, "order", node.order, end=" ") + else: + print(node.shape, end=" ") + + def visitPool(self, node: AST.Pool, args=None): + node.expr.depth = node.depth + 1 + print(indent * node.depth, node.poolType, end=" ") + self.visit(node.expr) + + def visitUOp(self, node: AST.UOp, args=None): + node.expr.depth = node.depth + 1 + print(indent * node.depth, AST.OperatorsSymbolDict[node.op.name], end=" ") + self.visit(node.expr) + + def visitBOp(self, node: AST.BOp, args=None): + node.expr1.depth = node.expr2.depth = node.depth + 1 + print(indent * node.depth, end=" ") + self.visit(node.expr1) + print(AST.OperatorsSymbolDict[node.op.name], end=" ") + self.visit(node.expr2) + + def visitFunc(self, node: AST.Func, args=None): + print(indent * node.depth, AST.OperatorsSymbolDict[node.op.name], end=" ") + node.expr.depth = node.depth + 1 + self.visit(node.expr) + + def visitLet(self, node: AST.Let, args=None): + if node.decl is not None: + node.decl.depth = node.depth + 1 + if node.expr is not None: + node.expr.depth = node.depth + 1 + print(indent * node.depth, "(", end=" ") + print("let", end=" ") + if hasattr(node.name, "type") and hasattr(node.name.type, "taint"): + print("<", node.decl.type.taint.name, ">", end=" ") + self.visit(node.name) + print("=", end=" ") + self.visit(node.decl) + print( + "{", + node.metadata[AST.ASTNode.mtdKeyTFOpName], + node.metadata[AST.ASTNode.mtdKeyTFNodeName], + "} in ", + end="\n", + ) + self.visit(node.expr) + print(")", end="") + + def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, args=None): + print(indent * node.depth, "UninterpFuncCall", node.funcName, end=" ") + for x in node.argsList: + self.visit(x) + + def visitArgMax(self, node: AST.ArgMax, args=None): + print(indent * node.depth, "ArgMax", end=" ") + self.visit(node.expr) + self.visit(node.dim) + + def visitReduce(self, node: AST.Reduce, args=None): + print( + indent * node.depth, + "reduce", + AST.OperatorsSymbolDict[node.op.name], + end=" ", + ) + self.visit(node.expr) + + def visitInput(self, node: AST.Input, args=None): + print( + indent * node.depth, + "input( ", + node.shape, + node.dataType, + " <", + node.inputByParty.name, + "> ", + end="", + ) + print(" )", end="") + + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None): + node.expr.depth = node.multExpr.depth = node.addExpr.depth = node.depth + 1 + print(indent * node.depth, "FusedBatchNorm", end=" ") + self.visit(node.expr) + self.visit(node.multExpr) + self.visit(node.addExpr) diff --git a/Athos/SeeDot/Codegen/CodegenBase.py b/Athos/SeeDot/Codegen/CodegenBase.py index a0da78a7..1e94f857 100644 --- a/Athos/SeeDot/Codegen/CodegenBase.py +++ b/Athos/SeeDot/Codegen/CodegenBase.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy as np from enum import Enum @@ -29,199 +29,200 @@ import Type as Type from Util import * + class CodegenBase: - #TODO : Clean this file of extra info - def __init__(self, writer): - self.out = writer - - def printOp(self, ir): - self.out.printf('%s', ir.name) - - def printVar(self, ir): - self.out.printf('%s', ir.idf) - for e in ir.idx: - self.out.printf('[') - self.print(e) - self.out.printf(']') - - def printBool(self, ir): - self.out.printf({True:'true', False:'false'}[ir.b]) - - def printIntUop(self, ir): - self.out.printf('(') - self.print(ir.op) - self.print(ir.e) - self.out.printf(')') - - def printTypeCast(self, ir): - self.out.printf('(') - self.out.printf('(' + ir.type + ')') - self.print(ir.expr) - self.out.printf(')') - - def printIntBop(self, ir): - self.out.printf('(') - self.print(ir.e1) - self.out.printf(' ') - self.print(ir.op) - self.out.printf(' ') - self.print(ir.e2) - self.out.printf(')') - - def printBoolUop(self, ir): - self.out.printf('(') - self.print(ir.op) - self.print(ir.e) - self.out.printf(')') - - def printBoolBop(self, ir): - self.out.printf('(') - self.print(ir.e1) - self.out.printf(' ') - self.print(ir.op) - self.out.printf(' ') - self.print(ir.e2) - self.out.printf(')') - - def printBoolCop(self, ir): - self.out.printf('(') - self.print(ir.e1) - self.out.printf(' ') - self.print(ir.op) - self.out.printf(' ') - self.print(ir.e2) - self.out.printf(')') - - def printCExpr(self, ir): - self.out.printf('(') - self.print(ir.cond) - self.out.printf(' ? ') - self.print(ir.et) - self.out.printf(' : ') - self.print(ir.ef) - self.out.printf(')') - - def printAssn(self, ir): - self.out.printf('', indent=True) - self.print(ir.var) - self.out.printf(' = ') - self.print(ir.e) - self.out.printf(';\n') - - def printIf(self, ir): - self.out.printf('if (', indent = True) - self.print(ir.cond) - self.out.printf(') {\n') - - self.out.increaseIndent() - for cmd in ir.trueCmds: - self.print(cmd) - self.out.decreaseIndent() - - if len(ir.falseCmds) == 0: - self.out.printf('}\n', indent=True) - return - - self.out.printf('} else {\n', indent=True) - - self.out.increaseIndent() - for cmd in ir.falseCmds: - self.print(cmd) - self.out.decreaseIndent() - - self.out.printf('}\n', indent=True) - - def printFor(self, ir): - self.printForHeader(ir) - self.out.increaseIndent() - for cmd in ir.cmd_l: - self.print(cmd) - self.out.decreaseIndent() - self.out.printf('}\n', indent=True) - - def printForHeader(self, ir): - self.out.printf('for (%s ', IR.DataType.getIntStr(), indent=True) - self.print(ir.var) - self.out.printf(' = %d; ', ir.st) - self.print(ir.cond) - self.out.printf('; ') - self.print(ir.var) - self.out.printf('++) {\n') - - def printWhile(self, ir): - self.out.printf('while (', indent=True) - self.print(ir.expr) - self.out.printf(') {\n') - self.out.increaseIndent() - for cmd in ir.cmds: - self.print(cmd) - self.out.decreaseIndent() - self.out.printf('}\n', indent=True) - - def printProg(self, ir): - for cmd in ir.cmd_l: - self.print(cmd) - - def printPrint(self, ir): - self.out.printf('cout << ', indent=True) - self.print(ir.expr) - self.out.printf(' << endl;\n') - - def printPrintAsFloat(self, ir): - self.out.printf('cout << ((float)(', indent=True) - self.print(ir.expr) - self.out.printf(')) * ' + str(2 ** ir.expnt) + ' << "";\n') - - def print(self, ir): - if isinstance(ir, IR.Int): - return self.printInt(ir) - elif isinstance(ir, IR.Var): - return self.printVar(ir) - elif isinstance(ir, IR.Bool): - return self.printBool(ir) - elif isinstance(ir, IR.IntUop): - return self.printIntUop(ir) - elif isinstance(ir, IR.Exp): - return self.printExp(ir) - elif isinstance(ir, IR.TypeCast): - return self.printTypeCast(ir) - elif isinstance(ir, IR.IntBop): - return self.printIntBop(ir) - elif isinstance(ir, IR.BoolUop): - return self.printBoolUop(ir) - elif isinstance(ir, IR.BoolBop): - return self.printBoolBop(ir) - elif isinstance(ir, IR.BoolCop): - return self.printBoolCop(ir) - elif isinstance(ir, IR.CExpr): - return self.printCExpr(ir) - elif isinstance(ir, IR.Assn): - return self.printAssn(ir) - elif isinstance(ir, IR.If): - return self.printIf(ir) - elif isinstance(ir, IR.For): - return self.printFor(ir) - elif isinstance(ir, IR.While): - return self.printWhile(ir) - elif isinstance(ir, IR.Comment): - return self.printComment(ir) - elif isinstance(ir, IR.Pragmas): - return self.printPragmas(ir) - elif isinstance(ir, IR.Prog): - return self.printProg(ir) - elif isinstance(ir, IR.Memset): - return self.printMemset(ir) - elif isinstance(ir, IR.Print): - return self.printPrint(ir) - elif isinstance(ir, IR.PrintAsFloat): - return self.printPrintAsFloat(ir) - elif isinstance(ir, IR.FuncCall): - return self.printFuncCall(ir) - elif isinstance(ir, IR.Op.Op): - return self.printOp(ir) - elif isinstance(ir, IR.Input): - return self.printInput(ir) - elif isinstance(ir, IR.Decl): - return self.printDecl(ir) - else: - assert False + # TODO : Clean this file of extra info + def __init__(self, writer): + self.out = writer + + def printOp(self, ir): + self.out.printf("%s", ir.name) + + def printVar(self, ir): + self.out.printf("%s", ir.idf) + for e in ir.idx: + self.out.printf("[") + self.print(e) + self.out.printf("]") + + def printBool(self, ir): + self.out.printf({True: "true", False: "false"}[ir.b]) + + def printIntUop(self, ir): + self.out.printf("(") + self.print(ir.op) + self.print(ir.e) + self.out.printf(")") + + def printTypeCast(self, ir): + self.out.printf("(") + self.out.printf("(" + ir.type + ")") + self.print(ir.expr) + self.out.printf(")") + + def printIntBop(self, ir): + self.out.printf("(") + self.print(ir.e1) + self.out.printf(" ") + self.print(ir.op) + self.out.printf(" ") + self.print(ir.e2) + self.out.printf(")") + + def printBoolUop(self, ir): + self.out.printf("(") + self.print(ir.op) + self.print(ir.e) + self.out.printf(")") + + def printBoolBop(self, ir): + self.out.printf("(") + self.print(ir.e1) + self.out.printf(" ") + self.print(ir.op) + self.out.printf(" ") + self.print(ir.e2) + self.out.printf(")") + + def printBoolCop(self, ir): + self.out.printf("(") + self.print(ir.e1) + self.out.printf(" ") + self.print(ir.op) + self.out.printf(" ") + self.print(ir.e2) + self.out.printf(")") + + def printCExpr(self, ir): + self.out.printf("(") + self.print(ir.cond) + self.out.printf(" ? ") + self.print(ir.et) + self.out.printf(" : ") + self.print(ir.ef) + self.out.printf(")") + + def printAssn(self, ir): + self.out.printf("", indent=True) + self.print(ir.var) + self.out.printf(" = ") + self.print(ir.e) + self.out.printf(";\n") + + def printIf(self, ir): + self.out.printf("if (", indent=True) + self.print(ir.cond) + self.out.printf(") {\n") + + self.out.increaseIndent() + for cmd in ir.trueCmds: + self.print(cmd) + self.out.decreaseIndent() + + if len(ir.falseCmds) == 0: + self.out.printf("}\n", indent=True) + return + + self.out.printf("} else {\n", indent=True) + + self.out.increaseIndent() + for cmd in ir.falseCmds: + self.print(cmd) + self.out.decreaseIndent() + + self.out.printf("}\n", indent=True) + + def printFor(self, ir): + self.printForHeader(ir) + self.out.increaseIndent() + for cmd in ir.cmd_l: + self.print(cmd) + self.out.decreaseIndent() + self.out.printf("}\n", indent=True) + + def printForHeader(self, ir): + self.out.printf("for (%s ", IR.DataType.getIntStr(), indent=True) + self.print(ir.var) + self.out.printf(" = %d; ", ir.st) + self.print(ir.cond) + self.out.printf("; ") + self.print(ir.var) + self.out.printf("++) {\n") + + def printWhile(self, ir): + self.out.printf("while (", indent=True) + self.print(ir.expr) + self.out.printf(") {\n") + self.out.increaseIndent() + for cmd in ir.cmds: + self.print(cmd) + self.out.decreaseIndent() + self.out.printf("}\n", indent=True) + + def printProg(self, ir): + for cmd in ir.cmd_l: + self.print(cmd) + + def printPrint(self, ir): + self.out.printf("cout << ", indent=True) + self.print(ir.expr) + self.out.printf(" << endl;\n") + + def printPrintAsFloat(self, ir): + self.out.printf("cout << ((float)(", indent=True) + self.print(ir.expr) + self.out.printf(")) * " + str(2 ** ir.expnt) + ' << "";\n') + + def print(self, ir): + if isinstance(ir, IR.Int): + return self.printInt(ir) + elif isinstance(ir, IR.Var): + return self.printVar(ir) + elif isinstance(ir, IR.Bool): + return self.printBool(ir) + elif isinstance(ir, IR.IntUop): + return self.printIntUop(ir) + elif isinstance(ir, IR.Exp): + return self.printExp(ir) + elif isinstance(ir, IR.TypeCast): + return self.printTypeCast(ir) + elif isinstance(ir, IR.IntBop): + return self.printIntBop(ir) + elif isinstance(ir, IR.BoolUop): + return self.printBoolUop(ir) + elif isinstance(ir, IR.BoolBop): + return self.printBoolBop(ir) + elif isinstance(ir, IR.BoolCop): + return self.printBoolCop(ir) + elif isinstance(ir, IR.CExpr): + return self.printCExpr(ir) + elif isinstance(ir, IR.Assn): + return self.printAssn(ir) + elif isinstance(ir, IR.If): + return self.printIf(ir) + elif isinstance(ir, IR.For): + return self.printFor(ir) + elif isinstance(ir, IR.While): + return self.printWhile(ir) + elif isinstance(ir, IR.Comment): + return self.printComment(ir) + elif isinstance(ir, IR.Pragmas): + return self.printPragmas(ir) + elif isinstance(ir, IR.Prog): + return self.printProg(ir) + elif isinstance(ir, IR.Memset): + return self.printMemset(ir) + elif isinstance(ir, IR.Print): + return self.printPrint(ir) + elif isinstance(ir, IR.PrintAsFloat): + return self.printPrintAsFloat(ir) + elif isinstance(ir, IR.FuncCall): + return self.printFuncCall(ir) + elif isinstance(ir, IR.Op.Op): + return self.printOp(ir) + elif isinstance(ir, IR.Input): + return self.printInput(ir) + elif isinstance(ir, IR.Decl): + return self.printDecl(ir) + else: + assert False diff --git a/Athos/SeeDot/Codegen/EzPC.py b/Athos/SeeDot/Codegen/EzPC.py index a14681a2..ac5a3589 100644 --- a/Athos/SeeDot/Codegen/EzPC.py +++ b/Athos/SeeDot/Codegen/EzPC.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import Util import IR.IR as IR @@ -29,115 +29,131 @@ import IR.IRUtil as IRUtil from Codegen.CodegenBase import CodegenBase + class EzPC(CodegenBase): - def __init__(self, writer, globalDecls, debugVar): - self.out = writer - self.globalDecls = globalDecls - self.consSFUsed = Util.Config.consSF - self.debugVar = debugVar - - def printAll(self, prog:IR.Prog, expr:IR.Expr): - self._out_prefix() - self.print(prog) - self._out_suffix(expr) - - def _out_prefix(self): - self.out.printf('\n\ndef void main(){\n') - self.out.increaseIndent() - self.printGlobalVarDecls() - - def printGlobalVarDecls(self): - for decl in self.globalDecls: - typ_str = IR.DataType.getIntStr() - idf_str = decl - declProp = self.globalDecls[decl] - curType = declProp[0] - if (len(declProp)>1): - # If label specified in decl, then use that - assert(len(declProp) >= 2 and len(declProp) <= 3) #For now only type, label and bitlen should be present - variableLabel = ('pl' if (declProp[1] == 'public') else 'al') - if (len(declProp) == 3): - bitlen = declProp[2] - typ_str = IR.DataType.getIntStrForBitlen(bitlen) - else: - # If variable unspecified, then default to secret - variableLabel = 'al' - if Type.isInt(curType): shape_str = '' - elif Type.isTensor(curType): shape_str = ''.join(['[' + str(n) + ']' for n in curType.shape]) - self.out.printf('%s_%s%s %s;\n', typ_str, variableLabel, shape_str, idf_str, indent=True) - self.out.printf('\n') - - def printFuncCall(self, ir:IR.FuncCall): - self.out.printf("%s(" % ir.name, indent = True) - keys = list(ir.argList) - for i in range(len(keys)): - arg = keys[i] - self.print(arg) - if i != len(keys) - 1: - self.out.printf(", ") - self.out.printf(");\n\n") - - def printForHeader(self, ir): - assert(ir.endInt is not None and ir.endCond is None) - self.out.printf('for ', indent=True) - self.print(ir.var) - self.out.printf(' = [%d: %d]{\n ', ir.st, ir.endInt) - - def printFor(self, ir): - self.printForHeader(ir) - self.out.increaseIndent() - for cmd in ir.cmd_l: - self.print(cmd) - self.out.decreaseIndent() - self.out.printf('};\n', indent=True) - - def printInt(self, ir:IR.Int): - if (isinstance(ir.n, np.int32)): - self.out.printf('%d', ir.n) - elif (isinstance(ir.n, np.int64)): - self.out.printf('%dL', ir.n) - else: - assert False - - def printInput(self, ir:IR.Input): - inputByPartyStr = ir.inputByParty.name - assert(inputByPartyStr == "SERVER" or inputByPartyStr == "CLIENT") #For now the only supported values of party to input is 0 or 1 - self.out.printf('input({0}, {1}, '.format(inputByPartyStr, ir.expr.idf), indent=True) - #assert(ir.dataType in ["DT_INT32"]) ####TODO: fix this - if Util.Config.wordLength == 32: - self.out.printf('int32_') - elif Util.Config.wordLength == 64: - self.out.printf('int64_') - else: - assert False - if ir.isSecret: - self.out.printf('al') - else: - self.out.printf('pl') - for curDim in ir.shape: - self.out.printf('[' + str(curDim) + ']') - self.out.printf(');\n\n') - - def printComment(self, ir): - self.out.printf('(* ' + ir.msg + ' *)\n', indent = True) - - def printDecl(self, ir): - typ_str = IR.DataType.getIntStrForBitlen(ir.bitlen) - variableLabel = 'pl' if not(ir.isSecret) else 'al' - - if Type.isInt(ir.typeExpr): shape_str = '' - elif Type.isTensor(ir.typeExpr): shape_str = ''.join(['[' + str(n) + ']' for n in ir.typeExpr.shape]) - self.out.printf('%s_%s%s %s', typ_str, variableLabel, shape_str, ir.varIdf, indent=True) - if (ir.value): - assert(Type.isInt(ir.typeExpr)) #In EzPC ints can be declared and assigned in same line, not tensors - self.out.printf(' = %s', str(ir.value[0])) - self.out.printf(';\n\n') - - def _out_suffix(self, expr:IR.Expr): - if self.debugVar is None: - self.out.printf('output(CLIENT, ' + expr.idf + ');\n', indent=True) - else: - self.out.printf('output(CLIENT, ' + self.debugVar + ');\n', indent=True) - self.out.decreaseIndent() - self.out.printf('}\n', indent=True) - + def __init__(self, writer, globalDecls, debugVar): + self.out = writer + self.globalDecls = globalDecls + self.consSFUsed = Util.Config.consSF + self.debugVar = debugVar + + def printAll(self, prog: IR.Prog, expr: IR.Expr): + self._out_prefix() + self.print(prog) + self._out_suffix(expr) + + def _out_prefix(self): + self.out.printf("\n\ndef void main(){\n") + self.out.increaseIndent() + self.printGlobalVarDecls() + + def printGlobalVarDecls(self): + for decl in self.globalDecls: + typ_str = IR.DataType.getIntStr() + idf_str = decl + declProp = self.globalDecls[decl] + curType = declProp[0] + if len(declProp) > 1: + # If label specified in decl, then use that + assert ( + len(declProp) >= 2 and len(declProp) <= 3 + ) # For now only type, label and bitlen should be present + variableLabel = "pl" if (declProp[1] == "public") else "al" + if len(declProp) == 3: + bitlen = declProp[2] + typ_str = IR.DataType.getIntStrForBitlen(bitlen) + else: + # If variable unspecified, then default to secret + variableLabel = "al" + if Type.isInt(curType): + shape_str = "" + elif Type.isTensor(curType): + shape_str = "".join(["[" + str(n) + "]" for n in curType.shape]) + self.out.printf( + "%s_%s%s %s;\n", typ_str, variableLabel, shape_str, idf_str, indent=True + ) + self.out.printf("\n") + + def printFuncCall(self, ir: IR.FuncCall): + self.out.printf("%s(" % ir.name, indent=True) + keys = list(ir.argList) + for i in range(len(keys)): + arg = keys[i] + self.print(arg) + if i != len(keys) - 1: + self.out.printf(", ") + self.out.printf(");\n\n") + + def printForHeader(self, ir): + assert ir.endInt is not None and ir.endCond is None + self.out.printf("for ", indent=True) + self.print(ir.var) + self.out.printf(" = [%d: %d]{\n ", ir.st, ir.endInt) + + def printFor(self, ir): + self.printForHeader(ir) + self.out.increaseIndent() + for cmd in ir.cmd_l: + self.print(cmd) + self.out.decreaseIndent() + self.out.printf("};\n", indent=True) + + def printInt(self, ir: IR.Int): + if isinstance(ir.n, np.int32): + self.out.printf("%d", ir.n) + elif isinstance(ir.n, np.int64): + self.out.printf("%dL", ir.n) + else: + assert False + + def printInput(self, ir: IR.Input): + inputByPartyStr = ir.inputByParty.name + assert ( + inputByPartyStr == "SERVER" or inputByPartyStr == "CLIENT" + ) # For now the only supported values of party to input is 0 or 1 + self.out.printf( + "input({0}, {1}, ".format(inputByPartyStr, ir.expr.idf), indent=True + ) + # assert(ir.dataType in ["DT_INT32"]) ####TODO: fix this + if Util.Config.wordLength == 32: + self.out.printf("int32_") + elif Util.Config.wordLength == 64: + self.out.printf("int64_") + else: + assert False + if ir.isSecret: + self.out.printf("al") + else: + self.out.printf("pl") + for curDim in ir.shape: + self.out.printf("[" + str(curDim) + "]") + self.out.printf(");\n\n") + + def printComment(self, ir): + self.out.printf("(* " + ir.msg + " *)\n", indent=True) + + def printDecl(self, ir): + typ_str = IR.DataType.getIntStrForBitlen(ir.bitlen) + variableLabel = "pl" if not (ir.isSecret) else "al" + + if Type.isInt(ir.typeExpr): + shape_str = "" + elif Type.isTensor(ir.typeExpr): + shape_str = "".join(["[" + str(n) + "]" for n in ir.typeExpr.shape]) + self.out.printf( + "%s_%s%s %s", typ_str, variableLabel, shape_str, ir.varIdf, indent=True + ) + if ir.value: + assert Type.isInt( + ir.typeExpr + ) # In EzPC ints can be declared and assigned in same line, not tensors + self.out.printf(" = %s", str(ir.value[0])) + self.out.printf(";\n\n") + + def _out_suffix(self, expr: IR.Expr): + if self.debugVar is None: + self.out.printf("output(CLIENT, " + expr.idf + ");\n", indent=True) + else: + self.out.printf("output(CLIENT, " + self.debugVar + ");\n", indent=True) + self.out.decreaseIndent() + self.out.printf("}\n", indent=True) diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index d81d536f..ffed6aad 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os, sys import _pickle as pickle @@ -32,7 +32,7 @@ import Type as Type from Type import InferType import IR.IRUtil as IRUtil -from AST.PrintAST import PrintAST +from AST.PrintAST import PrintAST from AST.MtdAST import MtdAST from IR.IRBuilderCSF import IRBuilderCSF from Codegen.EzPC import EzPC as EzPCCodegen @@ -40,114 +40,139 @@ import Optimizations.GarbageCollector as GarbageCollector from collections import OrderedDict + class Compiler: - def __init__(self, version, target, sfType, astFile, printASTBool, consSF, bitlen, outputFileName, - disableRMO, disableLivenessOpti, disableTruncOpti, disableAllOpti, debugVar): - assert(version == Util.Version.Fixed) - assert(target == Util.Target.EzPC) - assert(sfType == Util.SFType.Constant) - assert(astFile is not None) - assert(isinstance(printASTBool, bool)) - assert(consSF is not None) - assert(bitlen is not None) - assert(outputFileName is not None) - Util.Config.version = version - Util.Config.target = target - Util.Config.sfType = sfType - Util.Config.astFile = astFile - Util.Config.printASTBool = printASTBool - Util.Config.consSF = consSF - Util.Config.outputFileName = outputFileName - Util.Config.disableRMO = disableRMO - Util.Config.disableLivenessOpti = disableLivenessOpti - Util.Config.disableTruncOpti = disableTruncOpti - Util.Config.disableAllOpti = disableAllOpti - Util.Config.debugVar = debugVar - Util.Config.actualWordLength = int(bitlen) - if (Util.Config.actualWordLength > 32): - Util.Config.wordLength = 64 - else: - Util.Config.wordLength = 32 - - def insertStartEndFunctionCalls(self, res:(IR.Prog, IR.Expr)): - prog = res[0] - expr = res[1] - for ii in range(len(prog.cmd_l)): - if not(isinstance(prog.cmd_l[ii], IR.Input)) and not(isinstance(prog.cmd_l[ii], IR.Comment)): - prog.cmd_l.insert(ii, IR.FuncCall('StartComputation',[])) - break; - prog.cmd_l.append(IR.FuncCall('EndComputation', [])) - return (prog, expr) - - def fixOuputScale(self, res:(IR.Prog, IR.Expr), compiler:IRBuilderCSF): - prog = res[0] - expr = res[1] - output_scale = compiler.scaleFacMapping[expr.idf] - if output_scale == -1 or output_scale == Util.Config.consSF: - return (prog, expr) - elif output_scale > Util.Config.consSF: - scale_down = output_scale - Util.Config.consSF - type = compiler.typeInfo[expr.idf] - if Type.isInt(type): - output_shape = [] - if Type.isTensor(type): - output_shape = type.shape - - argsDict = OrderedDict() - funcName = "ScaleDown" - for ii, curDimSize in enumerate(output_shape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - funcName = funcName + str(len(output_shape)) - argsDict[expr] = "expr" - argsDict[IR.Int(scale_down,32)] = "consSF" - funcCall = IR.FuncCall(funcName, argsDict) - new_prog = IR.Prog([funcCall]) - prog = IRUtil.prog_merge(prog, new_prog) - return (prog, expr) - else: - assert False, "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format(output_scale, Util.Config.consSF) - - def run(self): - with open(Util.Config.astFile, 'rb') as ff: - ast = pickle.load(ff) - - if not(Util.Config.disableAllOpti): - if not(Util.Config.disableRMO): - print("Performing Relu-maxpool optimization...") - ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast) - print("Relu-maxpool optimization done.") - - if not(Util.Config.disableLivenessOpti): - print("Performing Garbage collection...") - mtdAST = MtdAST() - GC = GarbageCollector.GarbageCollector(ast) - GC.run([mtdAST]) - print("Garbage collection done.") - - # Perform type inference and annotate nodes with type information - InferType().visit(ast) - - if Util.Config.printASTBool: - PrintAST().visit(ast) - print("\n") - sys.stdout.flush() - - IRUtil.init() - compiler = IRBuilderCSF() - res = compiler.visit(ast) - res = self.fixOuputScale(res, compiler); - - Util.write_debug_info(compiler.name_mapping) - - # Insert a generic start_computation and end_computation function call after all input IR statements. - res = self.insertStartEndFunctionCalls(res); - writer = Writer(Util.Config.outputFileName) - debugVarEzPCName = compiler.name_mapping[Util.Config.debugVar] if (Util.Config.debugVar in compiler.name_mapping) else None - - if Util.forEzPC(): - codegen = EzPCCodegen(writer, compiler.globalDecls, debugVarEzPCName) - else: - assert False - - codegen.printAll(*res) - writer.close() + def __init__( + self, + version, + target, + sfType, + astFile, + printASTBool, + consSF, + bitlen, + outputFileName, + disableRMO, + disableLivenessOpti, + disableTruncOpti, + disableAllOpti, + debugVar, + ): + assert version == Util.Version.Fixed + assert target == Util.Target.EzPC + assert sfType == Util.SFType.Constant + assert astFile is not None + assert isinstance(printASTBool, bool) + assert consSF is not None + assert bitlen is not None + assert outputFileName is not None + Util.Config.version = version + Util.Config.target = target + Util.Config.sfType = sfType + Util.Config.astFile = astFile + Util.Config.printASTBool = printASTBool + Util.Config.consSF = consSF + Util.Config.outputFileName = outputFileName + Util.Config.disableRMO = disableRMO + Util.Config.disableLivenessOpti = disableLivenessOpti + Util.Config.disableTruncOpti = disableTruncOpti + Util.Config.disableAllOpti = disableAllOpti + Util.Config.debugVar = debugVar + Util.Config.actualWordLength = int(bitlen) + if Util.Config.actualWordLength > 32: + Util.Config.wordLength = 64 + else: + Util.Config.wordLength = 32 + + def insertStartEndFunctionCalls(self, res: (IR.Prog, IR.Expr)): + prog = res[0] + expr = res[1] + for ii in range(len(prog.cmd_l)): + if not (isinstance(prog.cmd_l[ii], IR.Input)) and not ( + isinstance(prog.cmd_l[ii], IR.Comment) + ): + prog.cmd_l.insert(ii, IR.FuncCall("StartComputation", [])) + break + prog.cmd_l.append(IR.FuncCall("EndComputation", [])) + return (prog, expr) + + def fixOuputScale(self, res: (IR.Prog, IR.Expr), compiler: IRBuilderCSF): + prog = res[0] + expr = res[1] + output_scale = compiler.scaleFacMapping[expr.idf] + if output_scale == -1 or output_scale == Util.Config.consSF: + return (prog, expr) + elif output_scale > Util.Config.consSF: + scale_down = output_scale - Util.Config.consSF + type = compiler.typeInfo[expr.idf] + if Type.isInt(type): + output_shape = [] + if Type.isTensor(type): + output_shape = type.shape + + argsDict = OrderedDict() + funcName = "ScaleDown" + for ii, curDimSize in enumerate(output_shape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + funcName = funcName + str(len(output_shape)) + argsDict[expr] = "expr" + argsDict[IR.Int(scale_down, 32)] = "consSF" + funcCall = IR.FuncCall(funcName, argsDict) + new_prog = IR.Prog([funcCall]) + prog = IRUtil.prog_merge(prog, new_prog) + return (prog, expr) + else: + assert ( + False + ), "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format( + output_scale, Util.Config.consSF + ) + + def run(self): + with open(Util.Config.astFile, "rb") as ff: + ast = pickle.load(ff) + + if not (Util.Config.disableAllOpti): + if not (Util.Config.disableRMO): + print("Performing Relu-maxpool optimization...") + ReluMaxpoolOpti.ReluMaxpoolOpti().visit(ast) + print("Relu-maxpool optimization done.") + + if not (Util.Config.disableLivenessOpti): + print("Performing Garbage collection...") + mtdAST = MtdAST() + GC = GarbageCollector.GarbageCollector(ast) + GC.run([mtdAST]) + print("Garbage collection done.") + + # Perform type inference and annotate nodes with type information + InferType().visit(ast) + + if Util.Config.printASTBool: + PrintAST().visit(ast) + print("\n") + sys.stdout.flush() + + IRUtil.init() + compiler = IRBuilderCSF() + res = compiler.visit(ast) + res = self.fixOuputScale(res, compiler) + + Util.write_debug_info(compiler.name_mapping) + + # Insert a generic start_computation and end_computation function call after all input IR statements. + res = self.insertStartEndFunctionCalls(res) + writer = Writer(Util.Config.outputFileName) + debugVarEzPCName = ( + compiler.name_mapping[Util.Config.debugVar] + if (Util.Config.debugVar in compiler.name_mapping) + else None + ) + + if Util.forEzPC(): + codegen = EzPCCodegen(writer, compiler.globalDecls, debugVarEzPCName) + else: + assert False + + codegen.printAll(*res) + writer.close() diff --git a/Athos/SeeDot/IR/IR.py b/Athos/SeeDot/IR/IR.py index 50362d0e..89c87102 100644 --- a/Athos/SeeDot/IR/IR.py +++ b/Athos/SeeDot/IR/IR.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" from enum import Enum import numpy as np @@ -28,326 +28,421 @@ import Util, Type import AST.AST as AST -#TODO - check if this can be cleaned up -class Op(): - Op = Enum('Op', '+ - * / << >> & | ^ ~ ! && || < <= > >= == != max .* ./') - Op.print = lambda self, writer: writer.printf('%s', self.name) - Op.op_list = lambda op_str: list(map(lambda x: Op.Op[x], op_str.split())) # op_str:str - Op.IsPrefixOp = lambda self: True if (self.name == 'max') else False - Op.IsPostfixOp = lambda self: False +# TODO - check if this can be cleaned up +class Op: + Op = Enum("Op", "+ - * / << >> & | ^ ~ ! && || < <= > >= == != max .* ./") + Op.print = lambda self, writer: writer.printf("%s", self.name) + Op.op_list = lambda op_str: list( + map(lambda x: Op.Op[x], op_str.split()) + ) # op_str:str + Op.IsPrefixOp = lambda self: True if (self.name == "max") else False + Op.IsPostfixOp = lambda self: False + class Expr: - pass + pass + class IntExpr(Expr): - pass + pass + class BoolExpr(Expr): - pass + pass + class Int(IntExpr): - @staticmethod - def negMax(): - return DataType.getNegMax() - @staticmethod - def max(): - return DataType.getMax() - def __init__(self, n:int, wordLen:int=None): - if not(wordLen): - wordLen = Util.Config.wordLength - self.n = DataType.getInt(n, wordLen) - def subst(self, from_idf:str, to_e:Expr): - if isinstance(self.n, np.int8): - return self.__class__(self.n, 8) - elif isinstance(self.n, np.int16): - return self.__class__(self.n, 16) - elif isinstance(self.n, np.int32): - return self.__class__(self.n, 32) - elif isinstance(self.n, np.int64): - return self.__class__(self.n, 64) - else: - assert False - return self.__class__(self.n) + @staticmethod + def negMax(): + return DataType.getNegMax() + + @staticmethod + def max(): + return DataType.getMax() + + def __init__(self, n: int, wordLen: int = None): + if not (wordLen): + wordLen = Util.Config.wordLength + self.n = DataType.getInt(n, wordLen) + + def subst(self, from_idf: str, to_e: Expr): + if isinstance(self.n, np.int8): + return self.__class__(self.n, 8) + elif isinstance(self.n, np.int16): + return self.__class__(self.n, 16) + elif isinstance(self.n, np.int32): + return self.__class__(self.n, 32) + elif isinstance(self.n, np.int64): + return self.__class__(self.n, 64) + else: + assert False + return self.__class__(self.n) + class Var(IntExpr): - def __init__(self, idf:str, idx:list=[], inputVar=False): - self.idf = idf - self.idx = idx - self.inputVar = inputVar - def subst(self, from_idf:str, to_e:Expr): - idx_new = list(map(lambda e: e.subst(from_idf, to_e), self.idx)) - if(self.idf != from_idf): - return self.__class__(self.idf, idx_new, self.inputVar) - else: - if(isinstance(to_e, Var)): - return self.__class__(to_e.idf, to_e.idx + idx_new, to_e.inputVar and self.inputVar) - elif(isinstance(to_e, Int)): - return to_e - else: assert False + def __init__(self, idf: str, idx: list = [], inputVar=False): + self.idf = idf + self.idx = idx + self.inputVar = inputVar + + def subst(self, from_idf: str, to_e: Expr): + idx_new = list(map(lambda e: e.subst(from_idf, to_e), self.idx)) + if self.idf != from_idf: + return self.__class__(self.idf, idx_new, self.inputVar) + else: + if isinstance(to_e, Var): + return self.__class__( + to_e.idf, to_e.idx + idx_new, to_e.inputVar and self.inputVar + ) + elif isinstance(to_e, Int): + return to_e + else: + assert False + class Bool(BoolExpr): - def __init__(self, b:bool): - self.b = b - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.b) + def __init__(self, b: bool): + self.b = b + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.b) + class IntUop(IntExpr): - def __init__(self, op:Op.Op, e:IntExpr): - assert(op in Op.Op.op_list('- ~')) - self.op = op - self.e = e - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.op, - self.e.subst(from_idf, to_e)) + def __init__(self, op: Op.Op, e: IntExpr): + assert op in Op.Op.op_list("- ~") + self.op = op + self.e = e + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.op, self.e.subst(from_idf, to_e)) + class Exp(IntExpr): - def __init__(self, e:IntExpr): - self.e = e - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.e.subst(from_idf, to_e)) + def __init__(self, e: IntExpr): + self.e = e + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.e.subst(from_idf, to_e)) + class TypeCast(IntExpr): - def __init__(self, type, expr:Expr): - self.type = type - self.expr = expr - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.type, self.expr.subst(from_idf, to_e)) - + def __init__(self, type, expr: Expr): + self.type = type + self.expr = expr + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.type, self.expr.subst(from_idf, to_e)) + + class IntBop(IntExpr): - def __init__(self, e1:IntExpr, op:Op.Op, e2:IntExpr): - assert(op in Op.Op.op_list('+ - * / << >> & | ^ ==')) - self.e1 = e1 - self.op = op - self.e2 = e2 - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.e1.subst(from_idf, to_e), - self.op, - self.e2.subst(from_idf, to_e)) + def __init__(self, e1: IntExpr, op: Op.Op, e2: IntExpr): + assert op in Op.Op.op_list("+ - * / << >> & | ^ ==") + self.e1 = e1 + self.op = op + self.e2 = e2 + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.e1.subst(from_idf, to_e), self.op, self.e2.subst(from_idf, to_e) + ) + class BoolUop(BoolExpr): - def __init__(self, op:Op.Op, e:BoolExpr): - assert(op in Op.Op.op_list('')) # ! - self.op = op - self.e = e - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.op, - self.e.subst(from_idf, to_e)) + def __init__(self, op: Op.Op, e: BoolExpr): + assert op in Op.Op.op_list("") # ! + self.op = op + self.e = e + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.op, self.e.subst(from_idf, to_e)) + class BoolBop(BoolExpr): - def __init__(self, e1:BoolExpr, op:Op.Op, e2:BoolExpr): - assert(op in Op.Op.op_list('&& ||')) # || ^ - self.e1 = e1 - self.op = op - self.e2 = e2 - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.e1.subst(from_idf, to_e), - self.op, - self.e2.subst(from_idf, to_e)) + def __init__(self, e1: BoolExpr, op: Op.Op, e2: BoolExpr): + assert op in Op.Op.op_list("&& ||") # || ^ + self.e1 = e1 + self.op = op + self.e2 = e2 + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.e1.subst(from_idf, to_e), self.op, self.e2.subst(from_idf, to_e) + ) + class BoolCop(BoolExpr): - def __init__(self, e1:IntExpr, op:Op.Op, e2:IntExpr): - assert(op in Op.Op.op_list('< <= > >= == !=')) # >= <= != - self.e1 = e1 - self.op = op - self.e2 = e2 - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.e1.subst(from_idf, to_e), - self.op, - self.e2.subst(from_idf, to_e)) + def __init__(self, e1: IntExpr, op: Op.Op, e2: IntExpr): + assert op in Op.Op.op_list("< <= > >= == !=") # >= <= != + self.e1 = e1 + self.op = op + self.e2 = e2 + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.e1.subst(from_idf, to_e), self.op, self.e2.subst(from_idf, to_e) + ) + class CExpr(Expr): - def __init__(self, cond:BoolExpr, et:Expr, ef:Expr): - self.cond = cond - self.et = et - self.ef = ef - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.cond.subst(from_idf, to_e), - self.et .subst(from_idf, to_e), - self.ef .subst(from_idf, to_e)) + def __init__(self, cond: BoolExpr, et: Expr, ef: Expr): + self.cond = cond + self.et = et + self.ef = ef + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.cond.subst(from_idf, to_e), + self.et.subst(from_idf, to_e), + self.ef.subst(from_idf, to_e), + ) + class Cmd: - pass + pass + class CmdList: - pass + pass + class Assn(Cmd): - def __init__(self, var:Var, e:Expr): - self.var = var - self.e = e - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.var.subst(from_idf, to_e), self.e.subst(from_idf, to_e)) + def __init__(self, var: Var, e: Expr): + self.var = var + self.e = e + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.var.subst(from_idf, to_e), self.e.subst(from_idf, to_e) + ) + class If(Cmd): - def __init__(self, cond:Expr, trueCmds:CmdList, falseCmds:CmdList=[]): - self.cond = cond - self.trueCmds = trueCmds - self.falseCmds = falseCmds - def subst(self, from_idf:str, to_e:Expr): - trueCmdsNew = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.trueCmds)) - falseCmdsNew = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.falseCmds)) - return self.__class__(self.cond.subst(from_idf, to_e), trueCmdsNew, falseCmdsNew) + def __init__(self, cond: Expr, trueCmds: CmdList, falseCmds: CmdList = []): + self.cond = cond + self.trueCmds = trueCmds + self.falseCmds = falseCmds + + def subst(self, from_idf: str, to_e: Expr): + trueCmdsNew = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.trueCmds)) + falseCmdsNew = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.falseCmds)) + return self.__class__( + self.cond.subst(from_idf, to_e), trueCmdsNew, falseCmdsNew + ) + class For(Cmd): - ''' - The terminationCond keyword arg should either consist of ending integer for the loop (keyword - endInt) - or the actual condition (keyword - endCond). - ''' - __endIntArgStr = 'endInt' - __endCondArgStr = 'endCond' - def __init__(self, var:Var, st:int, cmd_l:CmdList, fac=0, **terminationCond): - self.var = var - self.st = DataType.getInt(st) - self.cmd_l = cmd_l - self.factor = fac - self.endInt = None - self.endCond = None - if self.__endIntArgStr in terminationCond: - self.endInt = terminationCond[self.__endIntArgStr] - elif self.__endCondArgStr in terminationCond: - self.endCond = terminationCond[self.__endCondArgStr] - else: - assert False - - def subst(self, from_idf:str, to_e:Expr): - cmd_l_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.cmd_l)) - if self.endCond: - return For(self.var, self.st, cmd_l_new, self.factor, endCond=self.cond.subst(from_idf, to_e)) - else: - assert self.endInt is not None - return For(self.var, self.st, cmd_l_new, self.factor, endInt=self.endInt) + """ + The terminationCond keyword arg should either consist of ending integer for the loop (keyword - endInt) + or the actual condition (keyword - endCond). + """ + + __endIntArgStr = "endInt" + __endCondArgStr = "endCond" + + def __init__(self, var: Var, st: int, cmd_l: CmdList, fac=0, **terminationCond): + self.var = var + self.st = DataType.getInt(st) + self.cmd_l = cmd_l + self.factor = fac + self.endInt = None + self.endCond = None + if self.__endIntArgStr in terminationCond: + self.endInt = terminationCond[self.__endIntArgStr] + elif self.__endCondArgStr in terminationCond: + self.endCond = terminationCond[self.__endCondArgStr] + else: + assert False + + def subst(self, from_idf: str, to_e: Expr): + cmd_l_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.cmd_l)) + if self.endCond: + return For( + self.var, + self.st, + cmd_l_new, + self.factor, + endCond=self.cond.subst(from_idf, to_e), + ) + else: + assert self.endInt is not None + return For(self.var, self.st, cmd_l_new, self.factor, endInt=self.endInt) + class While(Cmd): - def __init__(self, expr:BoolExpr, cmds:CmdList): - self.expr = expr - self.cmds = cmds - def subst(self, from_idf:str, to_e:Expr): - cmds_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.cmds)) - return While(self.expr.subst(from_idf, to_e), cmds_new) + def __init__(self, expr: BoolExpr, cmds: CmdList): + self.expr = expr + self.cmds = cmds + + def subst(self, from_idf: str, to_e: Expr): + cmds_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.cmds)) + return While(self.expr.subst(from_idf, to_e), cmds_new) + class Comment(Cmd): - def __init__(self, msg): - self.msg = msg - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.msg) + def __init__(self, msg): + self.msg = msg + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.msg) + class Pragmas(Cmd): - def __init__(self, msg, vital=0): - self.msg = msg - self.vital = vital - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.msg, self.vital) - -class Prog(): - def __init__(self, cmd_l:CmdList, resource=0): - self.cmd_l = cmd_l - self.resource = resource - def subst(self, from_idf:str, to_e:Expr): - cmd_l_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.cmd_l)) - return self.__class__(cmd_l_new, self.resource) + def __init__(self, msg, vital=0): + self.msg = msg + self.vital = vital + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.msg, self.vital) + + +class Prog: + def __init__(self, cmd_l: CmdList, resource=0): + self.cmd_l = cmd_l + self.resource = resource + + def subst(self, from_idf: str, to_e: Expr): + cmd_l_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.cmd_l)) + return self.__class__(cmd_l_new, self.resource) + class Memset(Cmd): - #if dim==1 then single for-loop memset, else memset for 'dim' - def __init__(self, e:Var, len:int,dim=1, lens=[]): - self.e = e - self.len = len - self.dim = dim - self.lens = lens - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.e.subst(from_idf, to_e), self.len) + # if dim==1 then single for-loop memset, else memset for 'dim' + def __init__(self, e: Var, len: int, dim=1, lens=[]): + self.e = e + self.len = len + self.dim = dim + self.lens = lens + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.e.subst(from_idf, to_e), self.len) + class Print(Cmd): - def __init__(self, expr:Expr): - self.expr = expr - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.expr.subst(from_idf, to_e)) + def __init__(self, expr: Expr): + self.expr = expr + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.expr.subst(from_idf, to_e)) + class PrintAsFloat(Cmd): - def __init__(self, expr:Expr, expnt:int): - self.expr = expr - self.expnt = expnt - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.expr.subst(from_idf, to_e), self.expnt) + def __init__(self, expr: Expr, expnt: int): + self.expr = expr + self.expnt = expnt + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__(self.expr.subst(from_idf, to_e), self.expnt) + class FuncCall(Cmd): - def __init__(self, name, argList): - self.name = name - self.argList = argList - def subst(self, from_idf:str, to_e:Expr): - #argList_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.argList)) - argList_new = dict(map(lambda cmd: (cmd[0].subst(from_idf, to_e), cmd[1]), self.argList.items())) - return self.__class__(self.name, argList_new) + def __init__(self, name, argList): + self.name = name + self.argList = argList + + def subst(self, from_idf: str, to_e: Expr): + # argList_new = list(map(lambda cmd: cmd.subst(from_idf, to_e), self.argList)) + argList_new = dict( + map( + lambda cmd: (cmd[0].subst(from_idf, to_e), cmd[1]), self.argList.items() + ) + ) + return self.__class__(self.name, argList_new) + class Input(Cmd): - def __init__(self, expr:Expr, shape:list, dataType:str, isSecret=True, inputByParty=AST.Party.SERVER): - self.expr = expr - self.shape = shape - self.dataType = dataType - self.isSecret = isSecret - self.inputByParty = inputByParty + def __init__( + self, + expr: Expr, + shape: list, + dataType: str, + isSecret=True, + inputByParty=AST.Party.SERVER, + ): + self.expr = expr + self.shape = shape + self.dataType = dataType + self.isSecret = isSecret + self.inputByParty = inputByParty + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.expr.subst(from_idf, to_e), + self.shape, + self.dataType, + self.isSecret, + self.inputByParty, + ) - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.expr.subst(from_idf, to_e), self.shape, self.dataType, self.isSecret, self.inputByParty) class Decl(Cmd): - def __init__(self, varIdf:str, typeExpr:Type.Type, bitlen:int=-1, isSecret:bool=True, value:list=None): - self.varIdf = varIdf - self.typeExpr = typeExpr - self.bitlen = Util.Config.wordLength if bitlen==-1 else bitlen - self.isSecret = isSecret - if (value): - assert(isinstance(value,list)) - self.value = value - - def subst(self, from_idf:str, to_e:Expr): - return self.__class__(self.varIdf, self.typeExpr, self.bitlen, self.isSecret, self.value) - -class DataType(): - - intType = { Util.Target.EzPC: {32:np.int32, 64: np.int64} } - intStr = { Util.Target.EzPC: 'int' } - floatStr = "float" - - @staticmethod - def getInt(x:int, wordLen:int=None): - if not(wordLen): - wordLen = Util.Config.wordLength - target = Util.Config.target - return DataType.intType[target][wordLen](x) - - @staticmethod - def getIntClass(): - target = Util.Config.target - wordLen = Util.Config.wordLength - return DataType.intType[target][wordLen] - - @staticmethod - def getIntStr(): - target = Util.Config.target - potentialPrefix = DataType.intStr[target] - if (target == Util.Target.EzPC): - potentialPrefix = potentialPrefix + str(Util.Config.wordLength) - return potentialPrefix - - @staticmethod - def getIntStrForBitlen(bitlen): - target = Util.Config.target - potentialPrefix = DataType.intStr[target] - if (target == Util.Target.EzPC): - potentialPrefix = potentialPrefix + str(bitlen) - return potentialPrefix - - @staticmethod - def getFloatStr(): - return DataType.floatStr - - @staticmethod - def getNegMax(): - intClass = DataType.getIntClass() - return intClass(np.iinfo(intClass).min) - - @staticmethod - def getMax(): - intClass = DataType.getIntClass() - return intClass(np.iinfo(intClass).max) - + def __init__( + self, + varIdf: str, + typeExpr: Type.Type, + bitlen: int = -1, + isSecret: bool = True, + value: list = None, + ): + self.varIdf = varIdf + self.typeExpr = typeExpr + self.bitlen = Util.Config.wordLength if bitlen == -1 else bitlen + self.isSecret = isSecret + if value: + assert isinstance(value, list) + self.value = value + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.varIdf, self.typeExpr, self.bitlen, self.isSecret, self.value + ) + + +class DataType: + + intType = {Util.Target.EzPC: {32: np.int32, 64: np.int64}} + intStr = {Util.Target.EzPC: "int"} + floatStr = "float" + + @staticmethod + def getInt(x: int, wordLen: int = None): + if not (wordLen): + wordLen = Util.Config.wordLength + target = Util.Config.target + return DataType.intType[target][wordLen](x) + + @staticmethod + def getIntClass(): + target = Util.Config.target + wordLen = Util.Config.wordLength + return DataType.intType[target][wordLen] + + @staticmethod + def getIntStr(): + target = Util.Config.target + potentialPrefix = DataType.intStr[target] + if target == Util.Target.EzPC: + potentialPrefix = potentialPrefix + str(Util.Config.wordLength) + return potentialPrefix + + @staticmethod + def getIntStrForBitlen(bitlen): + target = Util.Config.target + potentialPrefix = DataType.intStr[target] + if target == Util.Target.EzPC: + potentialPrefix = potentialPrefix + str(bitlen) + return potentialPrefix + + @staticmethod + def getFloatStr(): + return DataType.floatStr + + @staticmethod + def getNegMax(): + intClass = DataType.getIntClass() + return intClass(np.iinfo(intClass).min) + + @staticmethod + def getMax(): + intClass = DataType.getIntClass() + return intClass(np.iinfo(intClass).max) diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index 8aa3370b..ee4697e2 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os, math, sys import operator @@ -35,237 +35,307 @@ from AST.ASTVisitor import ASTVisitor from AST.IRBuilderAST import IRBuilderAST + class IRBuilderCSF(IRBuilderAST): - varNameDelim = '' - def __init__(self, intPartBitwidth=-1): - # For tracking temp variables - self._var_cnt = 0 - self._iter_cnt = 0 - - # Global variables - # Used to note declarations which will go before any statements - # But since this affects memory consumption, use carefully - self.globalDecls = {} #Mapping of (identifier name (string) -> list of [type, secret/public variable, bitlen of decl]) - # The 2nd arg can be either 'secret' or 'public'. - # If public/secret unspecified, default to 'secret'. - # The 3rd arg is used to specify the bitlen of the decl. - - # Name mapping from SeeDot names to new names is useful for debugging - self.name_mapping = {} - - self.actualbitwidth = Util.Config.actualWordLength - - #This is for optimizing the #truncation calls - self.scaleFac = Util.Config.consSF - self.bitwidth = Util.Config.wordLength - self.intPartBitwidth = intPartBitwidth - if (self.intPartBitwidth==-1): - self.intPartBitwidth = self.bitwidth - 2*self.scaleFac - self.scaleFacMapping = {} - - def getConsSF(self): - return Util.Config.consSF - - # Variable and iterators creation - def getTempVars(self, n:int): - return [self.getTempVar() for i in range(n)] - - def getTempVar(self): - var = IR.Var('tmp' + str(self._var_cnt)) - self._var_cnt += 1 - return var - - def getTempIterators(self, n:int): - return [self.getTempIterator() for i in range(n)] - - def getTempIterator(self): - var = IR.Var('i' + str(self._iter_cnt)) - self._iter_cnt += 1 - return var - - # Computing exponent and intervals - def get_expnt(self, maxabs:float): # -> int - return self.getConsSF() - - def addTruncateFunctionCallHelper(self, exprNumToScaleDown:int, expr1:IR.Var, expr2:IR.Var, node:AST.BOp): - assert(isinstance(node, AST.BOp)) - if (exprNumToScaleDown==1): - exprToScaleDown = expr1 - nodeToScaleDown = node.expr1 - else: - assert(exprNumToScaleDown==2) - exprToScaleDown = expr2 - nodeToScaleDown = node.expr2 - return (exprToScaleDown, nodeToScaleDown) - - def addTruncateFunctionCall(self, node:AST.ASTNode, nodeTypeStr: str, expr:IR.Var, consSF:int): - comment = IR.Comment("Truncation before {0} node.".format(nodeTypeStr)) - argsDict = OrderedDict() - funcName = "ScaleDown" - if not(Type.isInt(node.type)): - outputShape = node.type.shape - for ii,curDimSize in enumerate(outputShape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - funcName = funcName + str(len(outputShape)) - argsDict[expr] = "expr" - argsDict[IR.Int(consSF,32)] = "consSF" - funcCall = IR.FuncCall(funcName, argsDict) - prog = IR.Prog([comment, funcCall]) - return prog - - def isModel(self, node:AST.ASTNode): - if(node.type.taint == Type.Taints.SERVER): - return True - else: - return False - - #================= - # Visit Functions - #================= - - def visitInt(self, node:AST.Int, args=None): - n = node.value - prog = IR.Prog([IR.Comment('Int node, isSecret = {0}.'.format(node.isSecret))]) - expr = self.getTempVar() - bitlen = -1 - if node.bitLen: - bitlen = node.bitLen - prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type, bitlen, node.isSecret, [n])]), prog) - if (not(Util.Config.disableTruncOpti)): - self.scaleFacMapping[expr.idf] = self.scaleFac if node.isScaled else 0 - return (prog, expr) - - def visitFloat(self, node:AST.Float, args=None): - r = node.value - p = self.get_expnt(abs(r)) - k = IR.DataType.getInt(np.ldexp(r, p)) - expr = self.getTempVar() - prog = IR.Prog([IR.Comment('Float to int : {0} to {1}, isSecret = {2}.'.format(str(r), str(k), node.isSecret))]) - prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type, -1, node.isSecret, [k])]), prog) - if (not(Util.Config.disableTruncOpti)): - self.scaleFacMapping[expr.idf] = self.scaleFac - return (prog, expr) - - def visitId(self, node:AST.ID, args=None): - idf = node.name - prog = IR.Prog([]) - expr = IR.Var(idf) - if not(Util.Config.disableTruncOpti): - assert(expr.idf in self.scaleFacMapping) - return (prog, expr) - - def visitDecl(self, node:AST.Decl, args=None): - def helperAssignGen(l1, l2, allComb): - if l2 == []: - allComb.append(l1) - else: - for cur in range(l2[0]): - helperAssignGen(l1 + [cur], l2[1:], allComb) - - prog = IR.Prog([]) - expr = self.getTempVar() - expr.inputVar = True - - # If there is a valueList, then add assignment commands - specialBitLen = -1 - if node.valueList: - # Add assignment statements for each element of the tensor in a different array - comment = IR.Comment(str(node.metadata)) - prog = IRUtil.prog_merge(prog, IR.Prog([comment, IR.Comment('Element assignments for ' + expr.idf)])) - allComb = [] - helperAssignGen([], node.shape, allComb) - for i,curComb in enumerate(allComb): - curVal = node.valueList[i] - finalVal = None - if isinstance(curVal, AST.Int): - finalVal = IR.Int(curVal.value, curVal.bitLen) - if (specialBitLen == -1 and curVal.bitLen != Util.Config.wordLength): - specialBitLen = curVal.bitLen - elif isinstance(curVal, AST.Float): - finalVal = IR.DataType.getInt(np.ldexp(curVal.value, Util.Config.consSF)) - else: - # Assuming the elements can only be either int or floats - assert False - prog = IRUtil.prog_merge(prog, IR.Prog([IR.Assn(IRUtil.addIndex(expr, list(map(lambda x: IR.Int(x), curComb))), - finalVal)])) - - prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, - node.type, - Util.Config.wordLength if specialBitLen == -1 else specialBitLen, - node.isSecret - )]), prog) - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[expr.idf] = self.scaleFac if node.isScaled else 0 - - return (prog, expr) - - def visitTranspose(self, node:AST.Transpose, args=None): - (inp_prog, inp_arr) = self.visit(node.expr) - inp_type = node.expr.type - out_type = node.type - inp_iters = self.getTempIterators(inp_type.dim) - out_iters = [] - perm = node.perm - if (perm is None): - perm = [i for i in reversed(range(len(inp_type.shape)))] - for i in perm: - out_iters.append(inp_iters[i]) - out_arr = self.getTempVar() - out_arr_expr = IRUtil.addIndex(out_arr, out_iters) - inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters) - assign_expr = IR.Assn(out_arr_expr, inp_arr_expr) - loop = IRUtil.loop(inp_type.shape, inp_iters, [assign_expr]) - # Finalize - comment1 = IR.Comment(str(node.metadata)) - comment2 = IR.Comment("transpose(" + inp_arr.idf + ", [" + ', '.join(str(e) for e in inp_type.shape) + "] --> [" + ', '.join(str(e) for e in out_type.shape) + "])") - transpose_prog = IR.Prog([comment1, comment2] + loop) - final_prog = IRUtil.prog_merge(inp_prog, transpose_prog) - - for var in inp_iters: - final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), final_prog) - final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog) - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[inp_arr.idf] - - return (final_prog, out_arr) - - def visitSlice(self, node:AST.Slice, args=None): - (inp_prog, inp_arr) = self.visit(node.expr) - inp_type = node.expr.type - out_type = node.type - out_iters = self.getTempIterators(out_type.dim) - inp_iters = [] - subscriptRanges = node.subscriptRanges - for idx,subrange in enumerate(subscriptRanges): - start = subrange[0] - inp_iters.append(IRUtil.add(out_iters[idx], IR.Int(start))) - - out_arr = self.getTempVar() - out_arr_expr = IRUtil.addIndex(out_arr, out_iters) - inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters) - assign_expr = IR.Assn(out_arr_expr, inp_arr_expr) - loop = IRUtil.loop(out_type.shape, out_iters, [assign_expr]) - # Finalize - comment1 = IR.Comment(str(node.metadata)) - comment2 = IR.Comment("slice(" + inp_arr.idf + ", [" + ', '.join(str(e) for e in inp_type.shape) + "] --> [" + ', '.join(str(e) for e in out_type.shape) + "])") - slice_prog = IR.Prog([comment1, comment2] + loop) - final_prog = IRUtil.prog_merge(inp_prog, slice_prog) - - for var in out_iters: - final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), final_prog) - final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog) - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[inp_arr.idf] - - return (final_prog, out_arr) - - def visitReshape(self, node:AST.Reshape, args=None): - (prog_1, expr_1) = self.visit(node.expr) - - ''' + varNameDelim = "" + + def __init__(self, intPartBitwidth=-1): + # For tracking temp variables + self._var_cnt = 0 + self._iter_cnt = 0 + + # Global variables + # Used to note declarations which will go before any statements + # But since this affects memory consumption, use carefully + self.globalDecls = ( + {} + ) # Mapping of (identifier name (string) -> list of [type, secret/public variable, bitlen of decl]) + # The 2nd arg can be either 'secret' or 'public'. + # If public/secret unspecified, default to 'secret'. + # The 3rd arg is used to specify the bitlen of the decl. + + # Name mapping from SeeDot names to new names is useful for debugging + self.name_mapping = {} + + self.actualbitwidth = Util.Config.actualWordLength + + # This is for optimizing the #truncation calls + self.scaleFac = Util.Config.consSF + self.bitwidth = Util.Config.wordLength + self.intPartBitwidth = intPartBitwidth + if self.intPartBitwidth == -1: + self.intPartBitwidth = self.bitwidth - 2 * self.scaleFac + self.scaleFacMapping = {} + + def getConsSF(self): + return Util.Config.consSF + + # Variable and iterators creation + def getTempVars(self, n: int): + return [self.getTempVar() for i in range(n)] + + def getTempVar(self): + var = IR.Var("tmp" + str(self._var_cnt)) + self._var_cnt += 1 + return var + + def getTempIterators(self, n: int): + return [self.getTempIterator() for i in range(n)] + + def getTempIterator(self): + var = IR.Var("i" + str(self._iter_cnt)) + self._iter_cnt += 1 + return var + + # Computing exponent and intervals + def get_expnt(self, maxabs: float): # -> int + return self.getConsSF() + + def addTruncateFunctionCallHelper( + self, exprNumToScaleDown: int, expr1: IR.Var, expr2: IR.Var, node: AST.BOp + ): + assert isinstance(node, AST.BOp) + if exprNumToScaleDown == 1: + exprToScaleDown = expr1 + nodeToScaleDown = node.expr1 + else: + assert exprNumToScaleDown == 2 + exprToScaleDown = expr2 + nodeToScaleDown = node.expr2 + return (exprToScaleDown, nodeToScaleDown) + + def addTruncateFunctionCall( + self, node: AST.ASTNode, nodeTypeStr: str, expr: IR.Var, consSF: int + ): + comment = IR.Comment("Truncation before {0} node.".format(nodeTypeStr)) + argsDict = OrderedDict() + funcName = "ScaleDown" + if not (Type.isInt(node.type)): + outputShape = node.type.shape + for ii, curDimSize in enumerate(outputShape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + funcName = funcName + str(len(outputShape)) + argsDict[expr] = "expr" + argsDict[IR.Int(consSF, 32)] = "consSF" + funcCall = IR.FuncCall(funcName, argsDict) + prog = IR.Prog([comment, funcCall]) + return prog + + def isModel(self, node: AST.ASTNode): + if node.type.taint == Type.Taints.SERVER: + return True + else: + return False + + # ================= + # Visit Functions + # ================= + + def visitInt(self, node: AST.Int, args=None): + n = node.value + prog = IR.Prog([IR.Comment("Int node, isSecret = {0}.".format(node.isSecret))]) + expr = self.getTempVar() + bitlen = -1 + if node.bitLen: + bitlen = node.bitLen + prog = IRUtil.prog_merge( + IR.Prog([IR.Decl(expr.idf, node.type, bitlen, node.isSecret, [n])]), prog + ) + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[expr.idf] = self.scaleFac if node.isScaled else 0 + return (prog, expr) + + def visitFloat(self, node: AST.Float, args=None): + r = node.value + p = self.get_expnt(abs(r)) + k = IR.DataType.getInt(np.ldexp(r, p)) + expr = self.getTempVar() + prog = IR.Prog( + [ + IR.Comment( + "Float to int : {0} to {1}, isSecret = {2}.".format( + str(r), str(k), node.isSecret + ) + ) + ] + ) + prog = IRUtil.prog_merge( + IR.Prog([IR.Decl(expr.idf, node.type, -1, node.isSecret, [k])]), prog + ) + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[expr.idf] = self.scaleFac + return (prog, expr) + + def visitId(self, node: AST.ID, args=None): + idf = node.name + prog = IR.Prog([]) + expr = IR.Var(idf) + if not (Util.Config.disableTruncOpti): + assert expr.idf in self.scaleFacMapping + return (prog, expr) + + def visitDecl(self, node: AST.Decl, args=None): + def helperAssignGen(l1, l2, allComb): + if l2 == []: + allComb.append(l1) + else: + for cur in range(l2[0]): + helperAssignGen(l1 + [cur], l2[1:], allComb) + + prog = IR.Prog([]) + expr = self.getTempVar() + expr.inputVar = True + + # If there is a valueList, then add assignment commands + specialBitLen = -1 + if node.valueList: + # Add assignment statements for each element of the tensor in a different array + comment = IR.Comment(str(node.metadata)) + prog = IRUtil.prog_merge( + prog, + IR.Prog([comment, IR.Comment("Element assignments for " + expr.idf)]), + ) + allComb = [] + helperAssignGen([], node.shape, allComb) + for i, curComb in enumerate(allComb): + curVal = node.valueList[i] + finalVal = None + if isinstance(curVal, AST.Int): + finalVal = IR.Int(curVal.value, curVal.bitLen) + if specialBitLen == -1 and curVal.bitLen != Util.Config.wordLength: + specialBitLen = curVal.bitLen + elif isinstance(curVal, AST.Float): + finalVal = IR.DataType.getInt( + np.ldexp(curVal.value, Util.Config.consSF) + ) + else: + # Assuming the elements can only be either int or floats + assert False + prog = IRUtil.prog_merge( + prog, + IR.Prog( + [ + IR.Assn( + IRUtil.addIndex( + expr, list(map(lambda x: IR.Int(x), curComb)) + ), + finalVal, + ) + ] + ), + ) + + prog = IRUtil.prog_merge( + IR.Prog( + [ + IR.Decl( + expr.idf, + node.type, + Util.Config.wordLength + if specialBitLen == -1 + else specialBitLen, + node.isSecret, + ) + ] + ), + prog, + ) + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[expr.idf] = self.scaleFac if node.isScaled else 0 + + return (prog, expr) + + def visitTranspose(self, node: AST.Transpose, args=None): + (inp_prog, inp_arr) = self.visit(node.expr) + inp_type = node.expr.type + out_type = node.type + inp_iters = self.getTempIterators(inp_type.dim) + out_iters = [] + perm = node.perm + if perm is None: + perm = [i for i in reversed(range(len(inp_type.shape)))] + for i in perm: + out_iters.append(inp_iters[i]) + out_arr = self.getTempVar() + out_arr_expr = IRUtil.addIndex(out_arr, out_iters) + inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters) + assign_expr = IR.Assn(out_arr_expr, inp_arr_expr) + loop = IRUtil.loop(inp_type.shape, inp_iters, [assign_expr]) + # Finalize + comment1 = IR.Comment(str(node.metadata)) + comment2 = IR.Comment( + "transpose(" + + inp_arr.idf + + ", [" + + ", ".join(str(e) for e in inp_type.shape) + + "] --> [" + + ", ".join(str(e) for e in out_type.shape) + + "])" + ) + transpose_prog = IR.Prog([comment1, comment2] + loop) + final_prog = IRUtil.prog_merge(inp_prog, transpose_prog) + + for var in inp_iters: + final_prog = IRUtil.prog_merge( + IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), final_prog + ) + final_prog = IRUtil.prog_merge( + IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog + ) + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[inp_arr.idf] + + return (final_prog, out_arr) + + def visitSlice(self, node: AST.Slice, args=None): + (inp_prog, inp_arr) = self.visit(node.expr) + inp_type = node.expr.type + out_type = node.type + out_iters = self.getTempIterators(out_type.dim) + inp_iters = [] + subscriptRanges = node.subscriptRanges + for idx, subrange in enumerate(subscriptRanges): + start = subrange[0] + inp_iters.append(IRUtil.add(out_iters[idx], IR.Int(start))) + + out_arr = self.getTempVar() + out_arr_expr = IRUtil.addIndex(out_arr, out_iters) + inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters) + assign_expr = IR.Assn(out_arr_expr, inp_arr_expr) + loop = IRUtil.loop(out_type.shape, out_iters, [assign_expr]) + # Finalize + comment1 = IR.Comment(str(node.metadata)) + comment2 = IR.Comment( + "slice(" + + inp_arr.idf + + ", [" + + ", ".join(str(e) for e in inp_type.shape) + + "] --> [" + + ", ".join(str(e) for e in out_type.shape) + + "])" + ) + slice_prog = IR.Prog([comment1, comment2] + loop) + final_prog = IRUtil.prog_merge(inp_prog, slice_prog) + + for var in out_iters: + final_prog = IRUtil.prog_merge( + IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), final_prog + ) + final_prog = IRUtil.prog_merge( + IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog + ) + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[inp_arr.idf] + + return (final_prog, out_arr) + + def visitReshape(self, node: AST.Reshape, args=None): + (prog_1, expr_1) = self.visit(node.expr) + + """ reshape(A, n, h, w) cmd1: t1 = t2 = t3 = 0; @@ -280,965 +350,1285 @@ def visitReshape(self, node:AST.Reshape, args=None): if (t2 == HH) t2 = 0; t1++; - ''' - - typ_1 = node.expr.type - typ_2 = node.type - - # Declare variables - expr_2 = self.getTempVar() - iters_1 = self.getTempIterators(typ_1.dim) - iters_2 = self.getTempIterators(typ_2.dim) - - # Initialize to 0 - cmd1 = [IR.Assn(var, IRUtil.zero) for var in iters_1] - - # Incrementing the first index - first_iter = iters_1[0] - cmd4 = IRUtil.incCmd(first_iter) - - # Incrementing other indices using a loop - cmd5 = [cmd4] - for i in range(1, typ_1.dim): - curr_iter = iters_1[i] - curr_size = IR.Int(typ_1.shape[i]) - cmd5 = [IRUtil.incCmd(curr_iter), IR.If(IRUtil.eq(curr_iter, curr_size), [IRUtil.initVarToZero(curr_iter)] + cmd5)] - - # Outer loop - # The iterators are selected based on the selection order specified by the user - loopShape = [] - loopIters = [] - - if(node.order): - for order in node.order: - order = order - 1 - loopShape.append(typ_2.shape[order]) - loopIters.append(iters_2[order]) - else: - loopShape = typ_2.shape - loopIters = iters_2 - - loop2 = IRUtil.loop(loopShape, loopIters, [IR.Assn(IRUtil.addIndex(expr_2, iters_2), IRUtil.addIndex(expr_1, iters_1))] + cmd5) - - # Finalize - comment1 = IR.Comment(str(node.metadata)) - comment2 = IR.Comment("reshape(" + expr_1.idf + ", " + ', '.join(str(e) for e in typ_2.shape) + ")") - reshape_prog = IR.Prog([comment1, comment2] + cmd1 + loop2) - prog_2 = IRUtil.prog_merge(prog_1, reshape_prog) - - for var in iters_1: - prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2) - for var in iters_2: - prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2) - prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog_2) - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf] - - return (prog_2, expr_2) - - def visitPool(self, node:AST.Pool, args=None): - (prog_1, expr_1) = self.visit(node.expr) - - [N, H, W, CI] = node.expr.type.shape - [N1, outH, outW, CI1] = node.type.shape - assert(N==N1 and CI==CI1) - [expr_2] = self.getTempVars(1) - - comment = IR.Comment(str(node.metadata)) - funcCallArgsDict = OrderedDict() - funcCallArgsDict[IR.Int(N1, 32)] = "N1" - funcCallArgsDict[IR.Int(outH, 32)] = "outH" - funcCallArgsDict[IR.Int(outW, 32)] = "outW" - funcCallArgsDict[IR.Int(CI1, 32)] = "CI1" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.FH], 32)] = "FH" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.FW], 32)] = "FW" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHLeft], 32)] = "zPadHLeft" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHRight], 32)] = "zPadHRight" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWLeft], 32)] = "zPadWLeft" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWRight], 32)] = "zPadWRight" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideH], 32)] = "strideH" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideW], 32)] = "strideW" - funcCallArgsDict[IR.Int(N, 32)] = "N" - funcCallArgsDict[IR.Int(H, 32)] = "H" - funcCallArgsDict[IR.Int(W, 32)] = "W" - funcCallArgsDict[IR.Int(CI, 32)] = "CI" - - funcCallArgsDict[expr_1] = "input" - funcCallArgsDict[expr_2] = "output" - - funcCall = IR.FuncCall(node.poolType, funcCallArgsDict) - prog_pool = IR.Prog([comment, funcCall]) - prog_2 = IRUtil.prog_merge(prog_1, prog_pool) - prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2) - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf] - - return (prog_2, expr_2) - - def visitUOp(self, node:AST.UOp, args=None): - (prog_1, expr_1) = self.visit(node.expr) - op = node.op - if op == AST.Operators.ADD: - return (prog_1, expr_1) - assert op == AST.Operators.SUB - - typ_2 = node.type - expr_2 = self.getTempVar() - - if Type.isInt(typ_2): - comment = IR.Comment(str(node.metadata)) - bitlen = node.expr.bitlen - decl = IR.Decl(expr_2.idf, node.type, typ_2.bitlen, typ_2.isSecret) - assign = IR.Assn(expr_2, IRUtil.negate(expr_1)) - prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment, decl, assign])) - else: - # decl fresh vars - iters = self.getTempIterators(typ_2.dim) - - # cmdl_assn - expr_1_elt = IRUtil.addIndex(expr_1, iters) - expr_2_elt = IRUtil.addIndex(expr_2, iters) - cmdl_assn = IRUtil.loop(typ_2.shape, iters, [IR.Assn(expr_2_elt, IRUtil.sub(IRUtil.zero, expr_1_elt))]) - comment = IR.Comment(str(node.metadata)) - prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment] + cmdl_assn)) - prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2) - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf] - - return (prog_2, expr_2) - - def visitBOp(self, node:AST.BOp, args=None): - op = node.op - if (op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal]): return self.visitBopAddOrSubLike(node) - elif (op in [AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]): return self.visitBopElemWiseOp(node) - elif op == AST.Operators.MUL: return self.visitBopMul(node) - elif op == AST.Operators.CONV: return self.visitBopConv(node) - elif op == AST.Operators.CONVTRANSPOSE: return self.visitBopConvTranspose(node) - else: assert False - - def visitBopAddOrSubLike(self, node:AST.BOp, args=None): - (prog_1, expr_1) = self.visit(node.expr1) - (prog_2, expr_2) = self.visit(node.expr2) - - op = node.op - if (op == AST.Operators.ADD): - op_ir = IR.Op.Op['+'] - elif (op == AST.Operators.SUB): - op_ir = IR.Op.Op['-'] - elif (op == AST.Operators.Equal): - op_ir = IR.Op.Op['=='] - else: - assert False - - node_type = node.type - out_arr = self.getTempVar() - cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf) - comment = IR.Comment(str(node.metadata)) - - if not(Util.Config.disableTruncOpti): - expr1_sf = self.scaleFacMapping[expr_1.idf] - expr2_sf = self.scaleFacMapping[expr_2.idf] - scaleUpFactor = -1 - if (expr1_sf > expr2_sf): - exprToScale = expr_2 - typeOfExprToScale = node.expr2.type - scaleUpFactor = expr1_sf - expr2_sf - self.scaleFacMapping[expr_2.idf] = expr1_sf - elif (expr2_sf > expr1_sf): - exprToScale = expr_1 - typeOfExprToScale = node.expr1.type - scaleUpFactor = expr2_sf - expr1_sf - self.scaleFacMapping[expr_1.idf] = expr2_sf - - if scaleUpFactor!=-1: - comm = IR.Comment('Scale up of args needed was found while doing OptimizeTruncations.') - argsDict = OrderedDict() - curFuncName = "ScaleUp" - if not(Type.isInt(typeOfExprToScale)): - outputShape = typeOfExprToScale.shape - for ii,curDimSize in enumerate(outputShape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - curFuncName += str(len(outputShape)) - argsDict[exprToScale] = "exprToScale, arg#{0}".format(2 if (expr1_sf>expr2_sf) else 1) - argsDict[IR.Int(scaleUpFactor, 32)] = "ScaleUpFactor" - funcCall = IR.FuncCall(curFuncName, argsDict) - - if Type.isInt(typeOfExprToScale) or typeOfExprToScale.shape == []: - assn_expr = IR.Assn(exprToScale, funcCall) - curProg = IR.Prog([comm,assn_expr]) - else: - curProg = IR.Prog([comm,funcCall]) - prog_1 = IRUtil.prog_merge(curProg, prog_1) - - self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[expr_1.idf] - - decl = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret) - - if Type.isInt(node_type): - assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2)) - out_prog = IR.Prog([assign]) - else: - outputShape = node_type.shape - inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape - inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape - - expected_output_shape, _, _ = Util.getBroadcastShapes(inp1_shape, inp2_shape) - assert(outputShape == expected_output_shape) - out_prog = IRUtil.generateBroadcastLoopBOp(expr_1, inp1_shape, expr_2, inp2_shape, out_arr, op_ir) - - out_prog = IRUtil.prog_merge(IR.Prog([comment, cmd0, decl]), out_prog) - out_prog = IRUtil.prog_merge(prog_1, prog_2, out_prog) - return (out_prog, out_arr) - - - # We first reshape both inputs and flatten them into 1d dims. - # For simplicity consider a non-broadcast example: - # inputs : inp1_arr[s1][s2], inp2_arr[s1][s2] - # after flattening : inp1_arr_flat[s1*s2], inp2_arr_flat[s1*s2] - # for i1=[0:s1] - # for i2=[0:s2] - # idx = i1*s2 + i2 - # inp1_arr_flat[idx] = inp1_arr[i1][i2] - # inp2_arr_flat[idx] = inp2_arr[i1][i2] - # If one input is from server and the other from model we can call an optimized version of mul - # ElemWiseActModelVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat) <- optimized - # OR - # ElemWiseSecretSharedVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat) - # Finally we reshape the flattened output - # for i1=[0:s1] - # for i2=[0:s2] - # idx = i1*s2 + i2 - # out_arr[i1][i2] = out_arr_flat[idx] - # Standard broadcast rules apply to generate these flattened tensors. - def visitBopElemWiseOp(self, node:AST.BOp, args=None): - (prog_1, expr_1) = self.visit(node.expr1) - (prog_2, expr_2) = self.visit(node.expr2) - - if (node.op == AST.Operators.ElemWiseMul): - op_ir = IR.Op.Op['.*'] - funcName = "ElemWiseMul" - elif (node.op == AST.Operators.ElemWiseDiv): - op_ir = IR.Op.Op['./'] - funcName = "ElemWiseDiv" - assert False, "Did not implement div yet" - else: - assert False, "Non mul/div elemwise op" - - comment = IR.Comment(str(node.metadata)) - cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf) - - node_type = node.type - # outArr[s1][s2] - out_arr = self.getTempVar() - decl_out_arr = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret) - - if Type.isInt(node_type): - assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2)) - out_prog = IR.Prog([assign]) - else: - # Flattening inputs - output_shape = node_type.shape - inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape - inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape - out_iters = self.getTempIterators(len(output_shape)) - expected_output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape) - assert(expected_output_shape == output_shape) - - # inp1_arr[i1][i2], inp2_arr[i1][i2], out_arr[i1][i2] - inp1_iters = IRUtil.getMaskedIters(broadcast_mask_1, out_iters, inp1_shape) - inp2_iters = IRUtil.getMaskedIters(broadcast_mask_2, out_iters, inp2_shape) - inp1_arr_expr = IRUtil.addIndex(expr_1, inp1_iters) - inp2_arr_expr = IRUtil.addIndex(expr_2, inp2_iters) - out_arr_expr = IRUtil.addIndex(out_arr, out_iters) - - flat_size = Util.get_volume(output_shape) - inp1_arr_flat = self.getTempVar() - inp2_arr_flat = self.getTempVar() - out_arr_flat = self.getTempVar() - flat_type = Type.Tensor([flat_size], node.expr1.type.bitlen, node.expr1.type.isSecret, node.expr1.type.taint) - # inp1_arr_flat[s1*s2] - # inp2_arr_flat[s1*s2] - # out_arr_flat[s1*s2] - decl_inp1_arr_flat = IR.Decl(inp1_arr_flat.idf, flat_type, node.expr1.type.bitlen, node.expr1.type.isSecret) - decl_inp2_arr_flat = IR.Decl(inp2_arr_flat.idf, flat_type, node.expr2.type.bitlen, node.expr2.type.isSecret) - decl_out_arr_flat = IR.Decl(out_arr_flat.idf, flat_type, node.type.bitlen, node.type.isSecret) - # idx - flat_idx = self.getTempVar() - decl_flat_idx = IR.Decl(flat_idx.idf, Type.Int(bitlen=32), bitlen=32, isSecret=False) - # For 4d, generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4); - flat_idx_expr = IR.Int(0,32) - for i in range(len(out_iters)): - vol = Util.get_volume(output_shape[i+1:]) - flat_idx_expr = IRUtil.add(flat_idx_expr, IRUtil.mul(out_iters[i], IR.Int(vol,32))) - # inp1_arr_flat[idx], inp2_arr_flat[idx], out_arr_flat[idx] - inp1_arr_flat_expr = IRUtil.addIndex(inp1_arr_flat, [flat_idx]) - inp2_arr_flat_expr = IRUtil.addIndex(inp2_arr_flat, [flat_idx]) - out_arr_flat_expr = IRUtil.addIndex(out_arr_flat, [flat_idx]) - # idx = i1*s2 + i2; - # inp1_arr_flat[idx] = inp1_arr[i1][i2] - # inp2_arr_flat[idx] = inp2_arr[i1][i2] - assign_flat_idx_expr = IR.Assn(flat_idx, flat_idx_expr) - assign_inp1_arr_flat = IR.Assn(inp1_arr_flat_expr, inp1_arr_expr) - assign_inp2_arr_flat = IR.Assn(inp2_arr_flat_expr, inp2_arr_expr) - # Flattening loop - # for i1=[0:s1] - # for i2=[0:s2] - # idx = i1*s2 + i2 - # inp1_arr_flat[idx] = inp1_arr[i1][i2] - # inp2_arr_flat[idx] = inp2_arr[i1][i2] - out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_inp1_arr_flat, assign_inp2_arr_flat]) - out_prog = IRUtil.Prog(out_loop) - decls = [decl_out_arr, decl_inp1_arr_flat, decl_inp2_arr_flat, decl_out_arr_flat, decl_flat_idx] - out_prog = IRUtil.prog_merge(IRUtil.Prog(decls), out_prog) - - # Insert call to mul/div functionality - argsDict = OrderedDict() - argsDict[IR.Int(flat_size, 32)] = "input_shape" - if (node.op == AST.Operators.ElemWiseDiv): - argsDict[inp1_arr_flat] = "A" - argsDict[inp2_arr_flat] = "B" - funcName = "ElemwiseSuperDuperSecretDiv" - assert False, "Elemwise div not implemented" - else: - # If either input is a model weight we can use an optimised version for mul - # Otherwise if both are derived from client input we use the hadmaard version - isMulOptimised = False - if not(self.isModel(node.expr1)) and not(self.isModel(node.expr2)): - argsDict[inp1_arr_flat] = "A" - argsDict[inp2_arr_flat] = "B" - else: - isMulOptimised = True - # Optimised version expects the second parameter to be an input from server - if self.isModel(node.expr2): - argsDict[inp1_arr_flat] = "A" - argsDict[inp2_arr_flat] = "B" - else: - # Shuffle the params. - argsDict[inp2_arr_flat] = "A" - argsDict[inp1_arr_flat] = "B" - funcName = "ElemWiseActModelVectorMult" if isMulOptimised else "ElemWiseSecretSharedVectorMult" - argsDict[out_arr_flat] = "Output" - funcCall = IR.FuncCall(funcName, argsDict) - out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) - - # Clear temp arrays - argsDict = OrderedDict() - argsDict[IR.Int(flat_size, 32)] = "size" - argsDict[inp1_arr_flat] = "A" - funcCall = IR.FuncCall("ClearMemSecret1", argsDict) - out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) - argsDict = OrderedDict() - argsDict[IR.Int(flat_size, 32)] = "size" - argsDict[inp2_arr_flat] = "A" - funcCall = IR.FuncCall("ClearMemSecret1", argsDict) - out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) - - # Unflatten output - assign_out_arr_flat = IR.Assn(out_arr_expr, out_arr_flat_expr) - out_loop = IRUtil.loop(output_shape, out_iters, [assign_flat_idx_expr, assign_out_arr_flat]) - out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog(out_loop)) - - argsDict = OrderedDict() - argsDict[IR.Int(flat_size, 32)] = "size" - argsDict[out_arr_flat] = "A" - funcCall = IR.FuncCall("ClearMemSecret1", argsDict) - out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) - - progExtraBefore = IR.Prog([]) - progExtraAfter = IR.Prog([]) - if (Util.Config.disableTruncOpti): - progExtraAfter = self.addTruncateFunctionCall(node, "ElemWiseMul", out_arr, Util.Config.consSF) - else: - inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr_1.idf] - expr2_sf = self.scaleFacMapping[expr_2.idf] - if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ElemWiseMul", expr_1, expr1_sf - self.scaleFac) - self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ElemWiseMul", expr_2, expr2_sf - self.scaleFac)) - self.scaleFacMapping[expr_2.idf] = self.scaleFac - self.scaleFacMapping[out_arr.idf] = 2*self.scaleFac - - out_prog = IRUtil.prog_merge(IRUtil.Prog([comment, cmd0]), progExtraBefore, out_prog, progExtraAfter) - return (out_prog, out_arr) - - def visitBopMul(self, node:AST.BOp, args=None): - typ_1 = node.expr1.type - typ_2 = node.expr2.type - typ_3 = node.type - if (Type.isInt(typ_3)): return self.visitBopMulInt(node) - elif (typ_1.dim == 0 or Type.isInt(typ_1) or typ_2.dim == 0 or Type.isInt(typ_2)): return self.visitBopMulScalar1DTensor(node) - else: return self.visitBopMul2DTensor(node) - - def visitBopMulInt(self, node:AST.BOp, args=None): - (prog_1, expr_1) = self.visit(node.expr1) - (prog_2, expr_2) = self.visit(node.expr2) - - expr_3 = self.getTempVar() - comment = IR.Comment(str(node.metadata)) - bitlen = node.expr.bitlen - decl = IR.Decl(expr_3.idf, node.type, node.type.bitlen, node.type.isSecret) - assign = IR.Assn(expr_3, IRUtil.mul(expr_1, expr_2)) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, decl, assign])) - - progExtraBefore = IR.Prog([]) - progExtraAfter = IR.Prog([]) - if (Util.Config.disableTruncOpti): - progExtraAfter = self.addTruncateFunctionCall(node, "MulInt", expr_3, Util.Config.consSF) - else: - inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr_1.idf] - expr2_sf = self.scaleFacMapping[expr_2.idf] - if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MulInt", expr_1, expr1_sf-self.scaleFac) - self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MulInt", expr_2, expr2_sf-self.scaleFac)) - self.scaleFacMapping[expr_2.idf] = self.scaleFac - self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac - - prog_3 = IRUtil.prog_merge(progExtraBefore, prog_3, progExtraAfter) - return (prog_3, expr_3) - - def visitBopMulScalar1DTensor(self, node:AST.BOp, args=None): - (prog_1, expr_1) = self.visit(node.expr1) - (prog_2, expr_2) = self.visit(node.expr2) - - typ_1 = node.expr1.type - typ_2 = node.expr2.type - typ_3 = node.type - - isIntMult = False - if typ_1.dim == 0 or Type.isInt(typ_1): - a, b = expr_1, expr_2 - outputShape = typ_2.shape - isIntMult = (Type.isInt(typ_1)) - else: - a, b = expr_2, expr_1 - outputShape = typ_1.shape - isIntMult = (Type.isInt(typ_2)) - - # decl fresh vars - expr_3 = self.getTempVar() - cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf) - funcCallArgsDict = OrderedDict() - for ii,curDimSize in enumerate(outputShape): - funcCallArgsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - funcCallArgsDict[a] = "A" - funcCallArgsDict[b] = "B" - funcCallArgsDict[expr_3] = "C" - progExtraBefore = IR.Prog([]) - progExtraAfter = IR.Prog([]) - if (Util.Config.disableTruncOpti): - progExtraAfter = self.addTruncateFunctionCall(node, "ScalarMul", expr_3, Util.Config.consSF) - else: - inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr_1.idf] - expr2_sf = self.scaleFacMapping[expr_2.idf] - if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ScalarMul", expr_1, expr1_sf-self.scaleFac) - self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ScalarMul", expr_2, expr2_sf-self.scaleFac)) - self.scaleFacMapping[expr_2.idf] = self.scaleFac - self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac - - funcCall = IR.FuncCall('ScalarMul' + self.varNameDelim + str(len(outputShape)), funcCallArgsDict) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall])) - prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter) - return (prog_3, expr_3) - - def visitBopMul2DTensor(self, node:AST.BOp, args=None): - (prog_1, expr_1) = self.visit(node.expr1) - (prog_2, expr_2) = self.visit(node.expr2) - - # decl fresh vars - expr_3 = self.getTempVar() - - typ_1 = node.expr1.type - typ_2 = node.expr2.type - typ_3 = node.type - - [I, J] = typ_1.shape - [J, K] = typ_2.shape - typ_mul = Type.Tensor([J]) - - shrT = Util.Config.consSF - - cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf) - funcCallArgsDict = OrderedDict() - funcCallArgsDict[IR.Int(I, 32)] = "I" - funcCallArgsDict[IR.Int(J, 32)] = "J" - funcCallArgsDict[IR.Int(K, 32)] = "K" - funcCallArgsDict[expr_1] = "A" - funcCallArgsDict[expr_2] = "B" - funcCallArgsDict[expr_3] = "C" - - # Add an arg as to which arg out of A or B is a model weight - # This is ok, since Athos is right now tailored for neural network inference - # and in inference, in every linear layer, either of A or B will be a model weight. - # This is required because for some backends, knowing which of A or B is a model weight - # can make a difference in their performance. - - assert (self.isModel(node.expr1) or self.isModel(node.expr2)), "Expecting one of A or B to be an input by the server (model weight)." - modelIsA = True - if (not self.isModel(node.expr1)): - modelIsA = False - funcCallArgsDict[IR.Bool(modelIsA)] = "modelIsA" - - progExtraBefore = IR.Prog([]) - progExtraAfter = IR.Prog([]) - if (Util.Config.disableTruncOpti): - progExtraAfter = self.addTruncateFunctionCall(node, "MatMul2D", expr_3, Util.Config.consSF) - else: - inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr_1.idf] - expr2_sf = self.scaleFacMapping[expr_2.idf] - if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MatMul2D", expr_1, expr1_sf-self.scaleFac) - self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MatMul2D", expr_2, expr2_sf-self.scaleFac)) - self.scaleFacMapping[expr_2.idf] = self.scaleFac - self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac - - funcCall = IR.FuncCall("MatMul2D", funcCallArgsDict) - comment = IR.Comment(str(node.metadata)) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([comment, cmd0, funcCall])) - prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter) - - return (prog_3, expr_3) - - def visitBopConv(self, node:AST.BOp, args=None): - (prog1, expr_1) = self.visit(node.expr1) - (prog2, expr_2) = self.visit(node.expr2) - - convDim = 2 - if (AST.PaddingKeysDict.ConvDim in node.options): - convDim = node.options[AST.PaddingKeysDict.ConvDim] - - if convDim == 2: - [N, H, W, CI] = node.expr1.type.shape - [FH, FW, CI1, CO] = node.expr2.type.shape - elif convDim == 3: - [N, D, H, W, CI] = node.expr1.type.shape - [FD, FH, FW, CI1, CO] = node.expr2.type.shape - else: - assert(False) - - returnExpr = self.getTempVar() - comment = IR.Comment(expr_1.idf + ' # ' + expr_2.idf + ', convDim = ' + str(convDim)) - funcCallArgsDict = OrderedDict() - funcCallArgsDict[IR.Int(N, 32)] = "N" - if convDim == 3: - funcCallArgsDict[IR.Int(D, 32)] = "D" - funcCallArgsDict[IR.Int(H, 32)] = "H" - funcCallArgsDict[IR.Int(W, 32)] = "W" - funcCallArgsDict[IR.Int(CI, 32)] = "CI" - if convDim == 3: - funcCallArgsDict[IR.Int(FD, 32)] = "FD" - funcCallArgsDict[IR.Int(FH, 32)] = "FH" - funcCallArgsDict[IR.Int(FW, 32)] = "FW" - funcCallArgsDict[IR.Int(CO, 32)] = "CO" - if convDim == 3: - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadDLeft], 32)] = "zPadDLeft" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadDRight], 32)] = "zPadDRight" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHLeft], 32)] = "zPadHLeft" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadHRight], 32)] = "zPadHRight" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWLeft], 32)] = "zPadWLeft" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.zPadWRight], 32)] = "zPadWRight" - if convDim == 3: - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideD], 32)] = "strideD" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideH], 32)] = "strideH" - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.strideW], 32)] = "strideW" - - isGroupConv = False - if AST.PaddingKeysDict.group in node.options.keys(): - funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.group], 32)] = "G" - isGroupConv = True - - funcCallArgsDict[expr_1] = "input" - funcCallArgsDict[expr_2] = "filter" - if convDim == 3: - funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" - funcCallArgsDict[returnExpr] = "output" - - if convDim == 2: - funcCallName = "Conv2D" - else: - funcCallName = "Conv3D" - - if isGroupConv: - funcCallName += "Group" - - funcCallName += "Wrapper" - - funcCall = IR.FuncCall(funcCallName, funcCallArgsDict) - progConv = IR.Prog([comment, funcCall]) - - progExtraBefore = IR.Prog([]) - progExtraAfter = IR.Prog([]) - if (Util.Config.disableTruncOpti): - progExtraAfter = self.addTruncateFunctionCall(node, "Conv", returnExpr, Util.Config.consSF) - else: - inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr_1.idf] - expr2_sf = self.scaleFacMapping[expr_2.idf] - if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "Conv", expr_1, expr1_sf-self.scaleFac) - self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "Conv", expr_2, expr2_sf-self.scaleFac)) - self.scaleFacMapping[expr_2.idf] = self.scaleFac - self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac - - returnProg = IRUtil.prog_merge(prog1, prog2, progExtraBefore, progConv) - returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg, progExtraAfter) - return (returnProg, returnExpr) - - def visitBopConvTranspose(self, node:AST.BOp, args=None): - (prog1, expr_1) = self.visit(node.expr1) - (prog2, expr_2) = self.visit(node.expr2) - - convDim = 2 - if (AST.PaddingKeysDict.ConvDim in node.options): - convDim = node.options[AST.PaddingKeysDict.ConvDim] - - if convDim==2: - [N, H_prime, W_prime, CI1] = node.expr1.type.shape - [FH, FW, CO, CI] = node.expr2.type.shape - elif convDim==3: - [N, D_prime, H_prime, W_prime, CI1] = node.expr1.type.shape - [FD, FH, FW, CO, CI] = node.expr2.type.shape - else: - assert(False) - assert(CI1 == CI) - - H = node.options[AST.PaddingKeysDict.outputImgH] #outputH - W = node.options[AST.PaddingKeysDict.outputImgW] #outputW - pad_h_total = node.options[AST.PaddingKeysDict.zPadHLeft] + node.options[AST.PaddingKeysDict.zPadHRight] - pad_w_total = node.options[AST.PaddingKeysDict.zPadWLeft] + node.options[AST.PaddingKeysDict.zPadWRight] - strideH = node.options[AST.PaddingKeysDict.strideH] - strideW = node.options[AST.PaddingKeysDict.strideW] - [pad_h_tr_total, stride_h_tr, h_prime_tilde] = AST.Operators.findConvTransposePadding(H, H_prime, FH, pad_h_total, strideH) - [pad_w_tr_total, stride_w_tr, w_prime_tilde] = AST.Operators.findConvTransposePadding(W, W_prime, FW, pad_w_total, strideW) - - [pad_h_tr_left, pad_h_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_h_tr_total) - [pad_w_tr_left, pad_w_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_w_tr_total) - - assert(AST.Operators.findConvOutputImgSize(h_prime_tilde, pad_h_tr_total, FH, stride_h_tr) == H) - assert(AST.Operators.findConvOutputImgSize(w_prime_tilde, pad_w_tr_total, FW, stride_w_tr) == W) - - if convDim == 3: - D = node.options[AST.PaddingKeysDict.outputImgD] #outputD - pad_d_total = node.options[AST.PaddingKeysDict.zPadDLeft] + node.options[AST.PaddingKeysDict.zPadDRight] - strideD = node.options[AST.PaddingKeysDict.strideD] - [pad_d_tr_total, stride_d_tr, d_prime_tilde] = AST.Operators.findConvTransposePadding(D, D_prime, FD, pad_d_total, strideD) - [pad_d_tr_left, pad_d_tr_right] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_d_tr_total) - assert(AST.Operators.findConvOutputImgSize(d_prime_tilde, pad_d_tr_total, FD, stride_d_tr) == D) - - returnExpr = self.getTempVar() - comment = IR.Comment(expr_1.idf + ' #T ' + expr_2.idf + ', convDim = ' + str(convDim)) - funcCallArgsDict = OrderedDict() - funcCallArgsDict[IR.Int(N, 32)] = "N" - if convDim==3: - funcCallArgsDict[IR.Int(D_prime, 32)] = "D_prime" - funcCallArgsDict[IR.Int(H_prime, 32)] = "H_prime" - funcCallArgsDict[IR.Int(W_prime, 32)] = "W_prime" - funcCallArgsDict[IR.Int(CI, 32)] = "CI" - if convDim==3: - funcCallArgsDict[IR.Int(FD, 32)] = "FD" - funcCallArgsDict[IR.Int(FH, 32)] = "FH" - funcCallArgsDict[IR.Int(FW, 32)] = "FW" - funcCallArgsDict[IR.Int(CO, 32)] = "CO" - if convDim==3: - funcCallArgsDict[IR.Int(D, 32)] = "D" - funcCallArgsDict[IR.Int(H, 32)] = "H" - funcCallArgsDict[IR.Int(W, 32)] = "W" - if convDim==3: - funcCallArgsDict[IR.Int(pad_d_tr_left, 32)] = "pad_d_tr_left" - funcCallArgsDict[IR.Int(pad_d_tr_right, 32)] = "pad_d_tr_right" - funcCallArgsDict[IR.Int(pad_h_tr_left, 32)] = "pad_h_tr_left" - funcCallArgsDict[IR.Int(pad_h_tr_right, 32)] = "pad_h_tr_right" - funcCallArgsDict[IR.Int(pad_w_tr_left, 32)] = "pad_w_tr_left" - funcCallArgsDict[IR.Int(pad_w_tr_right, 32)] = "pad_w_tr_right" - if convDim==3: - funcCallArgsDict[IR.Int(strideD, 32)] = "strideD" - funcCallArgsDict[IR.Int(strideH, 32)] = "strideH" - funcCallArgsDict[IR.Int(strideW, 32)] = "strideW" - - funcCallArgsDict[expr_1] = "input" - funcCallArgsDict[expr_2] = "filter" - if convDim == 3: - funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" - funcCallArgsDict[returnExpr] = "output" - - if convDim == 2: - funcCallName = "ConvTranspose2D" - else: - funcCallName = "ConvTranspose3D" - funcCallName += "Wrapper" - funcCall = IR.FuncCall(funcCallName, funcCallArgsDict) - - progConv = IR.Prog([comment, funcCall]) - - progExtraBefore = IR.Prog([]) - progExtraAfter = IR.Prog([]) - if (Util.Config.disableTruncOpti): - progExtraAfter = self.addTruncateFunctionCall(node, "ConvTranspose", returnExpr, self.scaleFac) - else: - inputs_same = (expr_1.idf == expr_2.idf) - expr1_sf = self.scaleFacMapping[expr_1.idf] - expr2_sf = self.scaleFacMapping[expr_2.idf] - if (expr1_sf > self.scaleFac): - progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ConvTranspose", expr_1, expr1_sf-self.scaleFac) - self.scaleFacMapping[expr_1.idf] = self.scaleFac - if (not inputs_same) and (expr2_sf > self.scaleFac): - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ConvTranspose", expr_2, expr2_sf-self.scaleFac)) - self.scaleFacMapping[expr2.idf] = self.scaleFac - self.scaleFacMapping[returnExpr.idf] = 2*self.scaleFac - - returnProg = IRUtil.prog_merge(prog1, prog2, progExtraBefore, progConv) - returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg, progExtraAfter) - return (returnProg, returnExpr) - - def visitFunc(self, node:AST.Func, args=None): - op = node.op - assert(op in [AST.Operators.Floor, AST.Operators.Shape, AST.Operators.RELU, AST.Operators.TANH, - AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT, - AST.Operators.ClearMemSecret, AST.Operators.ClearMemPublic]) - return self.visitFloorLike(node) - - def visitFloorLike(self, node:AST.Func, args=None): - (prog1, expr1) = self.visit(node.expr) - out_expr = self.getTempVar() - - if node.op == AST.Operators.Floor: - funcName = "Floor" - elif node.op == AST.Operators.Shape: - funcName = "Shape" - elif node.op == AST.Operators.RELU: - funcName = "Relu" - elif node.op == AST.Operators.TANH: - funcName = "Tanh" - elif node.op == AST.Operators.SIGMOID: - funcName = "Sigmoid" - elif node.op == AST.Operators.SQRT: - funcName = "Sqrt" - elif node.op == AST.Operators.RSQRT: - funcName = "Sqrt" - elif node.op == AST.Operators.ClearMemSecret: - funcName = "ClearMemSecret" - elif node.op == AST.Operators.ClearMemPublic: - funcName = "ClearMemPublic" - else: - assert False - - # We don't need to clear scalars. - if node.op == AST.Operators.ClearMemSecret or node.op == AST.Operators.ClearMemPublic: - if Type.isInt(node.expr.type): - return (prog1, expr1) - if node.expr.type.dim == 0: - return (prog1, expr1) - - argsList = OrderedDict() - - inputType = node.expr.type - if Type.isTensor(inputType): - for ii, curDim in enumerate(inputType.shape): - argsList[IR.Int(curDim, 32)] = "inShape_" + str(ii) - argsList[expr1] = "inArr" - - if Type.isTensor(node.type): - argsList[out_expr] = "outArr" - - if node.op == AST.Operators.Floor: - argsList[IR.Int(Util.Config.consSF,32)] = "curScale" - - progExtraBefore = IR.Prog([]) - if (Util.Config.disableTruncOpti): - if node.op == AST.Operators.RELU: - argsList[IR.Int(Util.Config.consSF,32)] = "consSF" - argsList[IR.Bool(False)] = "doTruncation" - if node.op in [AST.Operators.TANH, AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT]: - argsList[IR.Int(self.scaleFac,32)] = "sA" - argsList[IR.Int(self.scaleFac,32)] = "sB" - else: - final_sf = self.scaleFacMapping[expr1.idf] - if node.op == AST.Operators.RELU: - argsList[IR.Int(final_sf - self.scaleFac,32)] = "consSF" - if (final_sf > self.scaleFac): - #If it can't tolerate one more mult operation, then scale down here - assert(final_sf - self.scaleFac == self.scaleFac) - final_sf = self.scaleFac - argsList[IR.Bool(True)] = "doTruncation" - else: - argsList[IR.Bool(False)] = "doTruncation" - if node.op in [AST.Operators.TANH, AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT]: - # Since these class of fucntions can only handle input of 32 bitlength, we have to scale down - # inputs before calling them. - if final_sf > 32: - assert (final_sf > self.scaleFac), "The program scaling factor is invalid. Should be lesser than 32 if network has tan/sig/sqrt" - assert(final_sf - self.scaleFac == self.scaleFac) - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr, node.op.name, expr1, final_sf - self.scaleFac)) - self.scaleFacMapping[expr1.idf] = self.scaleFac - final_sf = self.scaleFac - argsList[IR.Int(final_sf,32)] = "sA" - argsList[IR.Int(final_sf,32)] = "sB" - self.scaleFacMapping[out_expr.idf] = final_sf - - # Tanh/Sigmoid/Sqrt impl only supports upto 32 bitwidth for input - if node.op in [AST.Operators.TANH, AST.Operators.SIGMOID, AST.Operators.SQRT, AST.Operators.RSQRT]: - argsList[IR.Int(min(32, self.actualbitwidth), 32)] = "bwA" - argsList[IR.Int(self.actualbitwidth, 32)] = "bwB" - if node.op == AST.Operators.SQRT: - argsList[IR.Bool(False)] = "inverse" - if node.op == AST.Operators.RSQRT: - argsList[IR.Bool(True)] = "inverse" - argsList[IR.Int(8,32)] = "LUTBITS" - - comment = IR.Comment(str(node.metadata)) - funcNameSuffix = "" - if Type.isTensor(inputType): - funcNameSuffix = str(len(inputType.shape)) - - progFinal = IR.Prog([comment, IR.FuncCall(funcName + self.varNameDelim + funcNameSuffix, argsList)]) - if Type.isTensor(node.type): - progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(out_expr.idf, node.type)]), progFinal) - - progFinal = IRUtil.prog_merge(prog1, progExtraBefore, progFinal) - return (progFinal, out_expr) - - def visitLet(self, node:AST.Let, args=None): - (prog_1, expr_1) = self.visit(node.decl) - typ_1 = node.decl.type - idf = node.name.name - if (not(Type.isInt(typ_1))): - self.name_mapping[idf] = expr_1.idf - if (not(Util.Config.disableTruncOpti)): - self.scaleFacMapping[idf] = self.scaleFacMapping[expr_1.idf] - (prog_2, expr_2) = self.visit(node.expr) - prog_2 = prog_2.subst(idf, expr_1) - expr_2 = expr_2.subst(idf, expr_1) - prog_3 = IRUtil.prog_merge(prog_1, prog_2) - return (prog_3, expr_2) - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args=None): - progList = [] - exprList = [] - for ii, curArg in enumerate(node.argsList): - (progN, exprN) = self.visit(curArg) - progList.append(progN) - exprList.append(exprN) - - returnExpr = self.getTempVar() - - funcName = node.funcName - funcName += self.varNameDelim + str(len(node.outputShape)) - for ii, curArg in enumerate(node.argsList): - if Type.isTensor(curArg.type): - curShape = curArg.type.shape - - # If len(shape) == 0 : that means its a float - no need to qualify - # the function name with 0 in that case, since its essentially - # become an int. - if (len(curShape) > 0): - funcName += self.varNameDelim + str(len(curShape)) - ### TODO : right now if random strings like int are passed, its being set as datatype int -- int datatype in - # unintrepreted func call is being used in a hacky way right now - - # Policy : - # First output tensor sizes are inserted in args. - # Then for each input tensor, its shape is inserted in args, followed by the input tensor itself. - # If the current input tensor has the same shape as any of the previous tensors, then its shape is not inserted. - funcArgsList = OrderedDict() - - if not(Util.Config.disableTruncOpti): - #TODO -- remove CreateTensor from uninterp function calls - for ii, curArg in enumerate(node.argsList): - curExpr = exprList[ii] - curScale = self.scaleFacMapping[curExpr.idf] - curType = curArg.type - if (not(Type.isInt(curType))) and (curScale > self.scaleFac) and (curType.isSecret): - curProg = self.addTruncateFunctionCall(curArg, "UninterpFuncCall", curExpr, curScale - self.scaleFac) - progList.insert(0,curProg) - self.scaleFacMapping[curExpr.idf] = self.scaleFac - - self.scaleFacMapping[returnExpr.idf] = self.scaleFac - - tensorShapesFound = {} - outputShape = node.type.shape - for ii, curDim in enumerate(outputShape): - funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii) - tensorShapesFound[tuple(outputShape)] = True - for ii in range(0, len(node.argsList)): - if node.outputDiffInpDims < 2 and Type.isTensor(node.argsList[ii].type): - curInpShape = node.argsList[ii].type.shape - if ((node.outputDiffInpDims == 1) or (node.outputDiffInpDims == 0 and tuple(curInpShape) not in tensorShapesFound)): - for jj, curDim in enumerate(curInpShape): - funcArgsList[IR.Int(curDim, 32)] = "Input_" + str(ii) + self.varNameDelim + str(jj) - tensorShapesFound[tuple(curInpShape)] = True - funcArgsList[exprList[ii]] = "inpExpr_" + str(ii) - funcArgsList[returnExpr] = "output" - - comment = IR.Comment(str(node.metadata)) - progFinal = progList[0] - if len(progList) > 1: - for ii in range(1, len(progList)): - progFinal = IRUtil.prog_merge(progFinal, progList[ii]) - progFinal = IRUtil.prog_merge(progFinal, IR.Prog([comment, IR.FuncCall(funcName, funcArgsList)])) - - progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, - node.type, - isSecret=False if node.isSecret is False else "secret")]), - progFinal) - return (progFinal, returnExpr) - - def visitArgMax(self, node:AST.ArgMax, args=None): - (prog_1, expr1) = self.visit(node.expr) - (prog_2, expr2) = self.visit(node.dim) - - tmpExpr = self.getTempVar() - outputShape = node.type.shape - - funcArgsList = OrderedDict() - outputShape = node.type.shape - for ii, curDim in enumerate(outputShape): - funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii) - for ii, curDim in enumerate(node.inShape): - funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii) - funcArgsList[expr1] = "inArr" - funcArgsList[expr2] = "dim" - funcArgsList[tmpExpr] = "outArr" - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[tmpExpr.idf] = -1 - - funcCall = IR.FuncCall("ArgMax" + self.varNameDelim + str(len(outputShape)), funcArgsList) - comment = IR.Comment(str(node.metadata)) - prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall])) - prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), prog_3) - return (prog_3, tmpExpr) - - def visitInput(self, node:AST.Input, args=None): - returnExpr = self.getTempVar() - returnExpr.inputVar = True - comment = IR.Comment(str(node.metadata)) - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[returnExpr.idf] = self.scaleFac - return (IR.Prog([comment, IR.Input(returnExpr, node.shape, node.dataType, node.isSecret, node.inputByParty)]), returnExpr) - - def visitReduce(self, node:AST.Reduce, args=None): - (prog_1, expr1) = self.visit(node.expr) - assert(node.op in [AST.Operators.ADD, AST.Operators.Mean]) - - # We already have the output shape so we dont need to calculate with keep_dims - - ''' + """ + + typ_1 = node.expr.type + typ_2 = node.type + + # Declare variables + expr_2 = self.getTempVar() + iters_1 = self.getTempIterators(typ_1.dim) + iters_2 = self.getTempIterators(typ_2.dim) + + # Initialize to 0 + cmd1 = [IR.Assn(var, IRUtil.zero) for var in iters_1] + + # Incrementing the first index + first_iter = iters_1[0] + cmd4 = IRUtil.incCmd(first_iter) + + # Incrementing other indices using a loop + cmd5 = [cmd4] + for i in range(1, typ_1.dim): + curr_iter = iters_1[i] + curr_size = IR.Int(typ_1.shape[i]) + cmd5 = [ + IRUtil.incCmd(curr_iter), + IR.If( + IRUtil.eq(curr_iter, curr_size), + [IRUtil.initVarToZero(curr_iter)] + cmd5, + ), + ] + + # Outer loop + # The iterators are selected based on the selection order specified by the user + loopShape = [] + loopIters = [] + + if node.order: + for order in node.order: + order = order - 1 + loopShape.append(typ_2.shape[order]) + loopIters.append(iters_2[order]) + else: + loopShape = typ_2.shape + loopIters = iters_2 + + loop2 = IRUtil.loop( + loopShape, + loopIters, + [ + IR.Assn( + IRUtil.addIndex(expr_2, iters_2), IRUtil.addIndex(expr_1, iters_1) + ) + ] + + cmd5, + ) + + # Finalize + comment1 = IR.Comment(str(node.metadata)) + comment2 = IR.Comment( + "reshape(" + + expr_1.idf + + ", " + + ", ".join(str(e) for e in typ_2.shape) + + ")" + ) + reshape_prog = IR.Prog([comment1, comment2] + cmd1 + loop2) + prog_2 = IRUtil.prog_merge(prog_1, reshape_prog) + + for var in iters_1: + prog_2 = IRUtil.prog_merge( + IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2 + ) + for var in iters_2: + prog_2 = IRUtil.prog_merge( + IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2 + ) + prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog_2) + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf] + + return (prog_2, expr_2) + + def visitPool(self, node: AST.Pool, args=None): + (prog_1, expr_1) = self.visit(node.expr) + + [N, H, W, CI] = node.expr.type.shape + [N1, outH, outW, CI1] = node.type.shape + assert N == N1 and CI == CI1 + [expr_2] = self.getTempVars(1) + + comment = IR.Comment(str(node.metadata)) + funcCallArgsDict = OrderedDict() + funcCallArgsDict[IR.Int(N1, 32)] = "N1" + funcCallArgsDict[IR.Int(outH, 32)] = "outH" + funcCallArgsDict[IR.Int(outW, 32)] = "outW" + funcCallArgsDict[IR.Int(CI1, 32)] = "CI1" + funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.FH], 32)] = "FH" + funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.FW], 32)] = "FW" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadHLeft], 32) + ] = "zPadHLeft" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadHRight], 32) + ] = "zPadHRight" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadWLeft], 32) + ] = "zPadWLeft" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadWRight], 32) + ] = "zPadWRight" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.strideH], 32) + ] = "strideH" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.strideW], 32) + ] = "strideW" + funcCallArgsDict[IR.Int(N, 32)] = "N" + funcCallArgsDict[IR.Int(H, 32)] = "H" + funcCallArgsDict[IR.Int(W, 32)] = "W" + funcCallArgsDict[IR.Int(CI, 32)] = "CI" + + funcCallArgsDict[expr_1] = "input" + funcCallArgsDict[expr_2] = "output" + + funcCall = IR.FuncCall(node.poolType, funcCallArgsDict) + prog_pool = IR.Prog([comment, funcCall]) + prog_2 = IRUtil.prog_merge(prog_1, prog_pool) + prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2) + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf] + + return (prog_2, expr_2) + + def visitUOp(self, node: AST.UOp, args=None): + (prog_1, expr_1) = self.visit(node.expr) + op = node.op + if op == AST.Operators.ADD: + return (prog_1, expr_1) + assert op == AST.Operators.SUB + + typ_2 = node.type + expr_2 = self.getTempVar() + + if Type.isInt(typ_2): + comment = IR.Comment(str(node.metadata)) + bitlen = node.expr.bitlen + decl = IR.Decl(expr_2.idf, node.type, typ_2.bitlen, typ_2.isSecret) + assign = IR.Assn(expr_2, IRUtil.negate(expr_1)) + prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment, decl, assign])) + else: + # decl fresh vars + iters = self.getTempIterators(typ_2.dim) + + # cmdl_assn + expr_1_elt = IRUtil.addIndex(expr_1, iters) + expr_2_elt = IRUtil.addIndex(expr_2, iters) + cmdl_assn = IRUtil.loop( + typ_2.shape, + iters, + [IR.Assn(expr_2_elt, IRUtil.sub(IRUtil.zero, expr_1_elt))], + ) + comment = IR.Comment(str(node.metadata)) + prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment] + cmdl_assn)) + prog_2 = IRUtil.prog_merge( + IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2 + ) + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf] + + return (prog_2, expr_2) + + def visitBOp(self, node: AST.BOp, args=None): + op = node.op + if op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal]: + return self.visitBopAddOrSubLike(node) + elif op in [AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]: + return self.visitBopElemWiseOp(node) + elif op == AST.Operators.MUL: + return self.visitBopMul(node) + elif op == AST.Operators.CONV: + return self.visitBopConv(node) + elif op == AST.Operators.CONVTRANSPOSE: + return self.visitBopConvTranspose(node) + else: + assert False + + def visitBopAddOrSubLike(self, node: AST.BOp, args=None): + (prog_1, expr_1) = self.visit(node.expr1) + (prog_2, expr_2) = self.visit(node.expr2) + + op = node.op + if op == AST.Operators.ADD: + op_ir = IR.Op.Op["+"] + elif op == AST.Operators.SUB: + op_ir = IR.Op.Op["-"] + elif op == AST.Operators.Equal: + op_ir = IR.Op.Op["=="] + else: + assert False + + node_type = node.type + out_arr = self.getTempVar() + cmd0 = IR.Comment(expr_1.idf + " " + op_ir.name + " " + expr_2.idf) + comment = IR.Comment(str(node.metadata)) + + if not (Util.Config.disableTruncOpti): + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] + scaleUpFactor = -1 + if expr1_sf > expr2_sf: + exprToScale = expr_2 + typeOfExprToScale = node.expr2.type + scaleUpFactor = expr1_sf - expr2_sf + self.scaleFacMapping[expr_2.idf] = expr1_sf + elif expr2_sf > expr1_sf: + exprToScale = expr_1 + typeOfExprToScale = node.expr1.type + scaleUpFactor = expr2_sf - expr1_sf + self.scaleFacMapping[expr_1.idf] = expr2_sf + + if scaleUpFactor != -1: + comm = IR.Comment( + "Scale up of args needed was found while doing OptimizeTruncations." + ) + argsDict = OrderedDict() + curFuncName = "ScaleUp" + if not (Type.isInt(typeOfExprToScale)): + outputShape = typeOfExprToScale.shape + for ii, curDimSize in enumerate(outputShape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + curFuncName += str(len(outputShape)) + argsDict[exprToScale] = "exprToScale, arg#{0}".format( + 2 if (expr1_sf > expr2_sf) else 1 + ) + argsDict[IR.Int(scaleUpFactor, 32)] = "ScaleUpFactor" + funcCall = IR.FuncCall(curFuncName, argsDict) + + if Type.isInt(typeOfExprToScale) or typeOfExprToScale.shape == []: + assn_expr = IR.Assn(exprToScale, funcCall) + curProg = IR.Prog([comm, assn_expr]) + else: + curProg = IR.Prog([comm, funcCall]) + prog_1 = IRUtil.prog_merge(curProg, prog_1) + + self.scaleFacMapping[out_arr.idf] = self.scaleFacMapping[expr_1.idf] + + decl = IR.Decl(out_arr.idf, node_type, node_type.bitlen, node_type.isSecret) + + if Type.isInt(node_type): + assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2)) + out_prog = IR.Prog([assign]) + else: + outputShape = node_type.shape + inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape + inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape + + expected_output_shape, _, _ = Util.getBroadcastShapes( + inp1_shape, inp2_shape + ) + assert outputShape == expected_output_shape + out_prog = IRUtil.generateBroadcastLoopBOp( + expr_1, inp1_shape, expr_2, inp2_shape, out_arr, op_ir + ) + + out_prog = IRUtil.prog_merge(IR.Prog([comment, cmd0, decl]), out_prog) + out_prog = IRUtil.prog_merge(prog_1, prog_2, out_prog) + return (out_prog, out_arr) + + # We first reshape both inputs and flatten them into 1d dims. + # For simplicity consider a non-broadcast example: + # inputs : inp1_arr[s1][s2], inp2_arr[s1][s2] + # after flattening : inp1_arr_flat[s1*s2], inp2_arr_flat[s1*s2] + # for i1=[0:s1] + # for i2=[0:s2] + # idx = i1*s2 + i2 + # inp1_arr_flat[idx] = inp1_arr[i1][i2] + # inp2_arr_flat[idx] = inp2_arr[i1][i2] + # If one input is from server and the other from model we can call an optimized version of mul + # ElemWiseActModelVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat) <- optimized + # OR + # ElemWiseSecretSharedVectorMult(s1*s2, inp1_arr_flat, inp2_arr_flat, out_arr_flat) + # Finally we reshape the flattened output + # for i1=[0:s1] + # for i2=[0:s2] + # idx = i1*s2 + i2 + # out_arr[i1][i2] = out_arr_flat[idx] + # Standard broadcast rules apply to generate these flattened tensors. + def visitBopElemWiseOp(self, node: AST.BOp, args=None): + (prog_1, expr_1) = self.visit(node.expr1) + (prog_2, expr_2) = self.visit(node.expr2) + + if node.op == AST.Operators.ElemWiseMul: + op_ir = IR.Op.Op[".*"] + funcName = "ElemWiseMul" + elif node.op == AST.Operators.ElemWiseDiv: + op_ir = IR.Op.Op["./"] + funcName = "ElemWiseDiv" + assert False, "Did not implement div yet" + else: + assert False, "Non mul/div elemwise op" + + comment = IR.Comment(str(node.metadata)) + cmd0 = IR.Comment(expr_1.idf + " " + op_ir.name + " " + expr_2.idf) + + node_type = node.type + # outArr[s1][s2] + out_arr = self.getTempVar() + decl_out_arr = IR.Decl( + out_arr.idf, node_type, node_type.bitlen, node_type.isSecret + ) + + if Type.isInt(node_type): + assign = IR.Assn(out_arr, IR.IntBop(expr_1, op_ir, expr_2)) + out_prog = IR.Prog([assign]) + else: + # Flattening inputs + output_shape = node_type.shape + inp1_shape = [] if Type.isInt(node.expr1.type) else node.expr1.type.shape + inp2_shape = [] if Type.isInt(node.expr2.type) else node.expr2.type.shape + out_iters = self.getTempIterators(len(output_shape)) + ( + expected_output_shape, + broadcast_mask_1, + broadcast_mask_2, + ) = Util.getBroadcastShapes(inp1_shape, inp2_shape) + assert expected_output_shape == output_shape + + # inp1_arr[i1][i2], inp2_arr[i1][i2], out_arr[i1][i2] + inp1_iters = IRUtil.getMaskedIters(broadcast_mask_1, out_iters, inp1_shape) + inp2_iters = IRUtil.getMaskedIters(broadcast_mask_2, out_iters, inp2_shape) + inp1_arr_expr = IRUtil.addIndex(expr_1, inp1_iters) + inp2_arr_expr = IRUtil.addIndex(expr_2, inp2_iters) + out_arr_expr = IRUtil.addIndex(out_arr, out_iters) + + flat_size = Util.get_volume(output_shape) + inp1_arr_flat = self.getTempVar() + inp2_arr_flat = self.getTempVar() + out_arr_flat = self.getTempVar() + flat_type = Type.Tensor( + [flat_size], + node.expr1.type.bitlen, + node.expr1.type.isSecret, + node.expr1.type.taint, + ) + # inp1_arr_flat[s1*s2] + # inp2_arr_flat[s1*s2] + # out_arr_flat[s1*s2] + decl_inp1_arr_flat = IR.Decl( + inp1_arr_flat.idf, + flat_type, + node.expr1.type.bitlen, + node.expr1.type.isSecret, + ) + decl_inp2_arr_flat = IR.Decl( + inp2_arr_flat.idf, + flat_type, + node.expr2.type.bitlen, + node.expr2.type.isSecret, + ) + decl_out_arr_flat = IR.Decl( + out_arr_flat.idf, flat_type, node.type.bitlen, node.type.isSecret + ) + # idx + flat_idx = self.getTempVar() + decl_flat_idx = IR.Decl( + flat_idx.idf, Type.Int(bitlen=32), bitlen=32, isSecret=False + ) + # For 4d, generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4); + flat_idx_expr = IR.Int(0, 32) + for i in range(len(out_iters)): + vol = Util.get_volume(output_shape[i + 1 :]) + flat_idx_expr = IRUtil.add( + flat_idx_expr, IRUtil.mul(out_iters[i], IR.Int(vol, 32)) + ) + # inp1_arr_flat[idx], inp2_arr_flat[idx], out_arr_flat[idx] + inp1_arr_flat_expr = IRUtil.addIndex(inp1_arr_flat, [flat_idx]) + inp2_arr_flat_expr = IRUtil.addIndex(inp2_arr_flat, [flat_idx]) + out_arr_flat_expr = IRUtil.addIndex(out_arr_flat, [flat_idx]) + # idx = i1*s2 + i2; + # inp1_arr_flat[idx] = inp1_arr[i1][i2] + # inp2_arr_flat[idx] = inp2_arr[i1][i2] + assign_flat_idx_expr = IR.Assn(flat_idx, flat_idx_expr) + assign_inp1_arr_flat = IR.Assn(inp1_arr_flat_expr, inp1_arr_expr) + assign_inp2_arr_flat = IR.Assn(inp2_arr_flat_expr, inp2_arr_expr) + # Flattening loop + # for i1=[0:s1] + # for i2=[0:s2] + # idx = i1*s2 + i2 + # inp1_arr_flat[idx] = inp1_arr[i1][i2] + # inp2_arr_flat[idx] = inp2_arr[i1][i2] + out_loop = IRUtil.loop( + output_shape, + out_iters, + [assign_flat_idx_expr, assign_inp1_arr_flat, assign_inp2_arr_flat], + ) + out_prog = IRUtil.Prog(out_loop) + decls = [ + decl_out_arr, + decl_inp1_arr_flat, + decl_inp2_arr_flat, + decl_out_arr_flat, + decl_flat_idx, + ] + out_prog = IRUtil.prog_merge(IRUtil.Prog(decls), out_prog) + + # Insert call to mul/div functionality + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "input_shape" + if node.op == AST.Operators.ElemWiseDiv: + argsDict[inp1_arr_flat] = "A" + argsDict[inp2_arr_flat] = "B" + funcName = "ElemwiseSuperDuperSecretDiv" + assert False, "Elemwise div not implemented" + else: + # If either input is a model weight we can use an optimised version for mul + # Otherwise if both are derived from client input we use the hadmaard version + isMulOptimised = False + if not (self.isModel(node.expr1)) and not (self.isModel(node.expr2)): + argsDict[inp1_arr_flat] = "A" + argsDict[inp2_arr_flat] = "B" + else: + isMulOptimised = True + # Optimised version expects the second parameter to be an input from server + if self.isModel(node.expr2): + argsDict[inp1_arr_flat] = "A" + argsDict[inp2_arr_flat] = "B" + else: + # Shuffle the params. + argsDict[inp2_arr_flat] = "A" + argsDict[inp1_arr_flat] = "B" + funcName = ( + "ElemWiseActModelVectorMult" + if isMulOptimised + else "ElemWiseSecretSharedVectorMult" + ) + argsDict[out_arr_flat] = "Output" + funcCall = IR.FuncCall(funcName, argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + + # Clear temp arrays + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "size" + argsDict[inp1_arr_flat] = "A" + funcCall = IR.FuncCall("ClearMemSecret1", argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "size" + argsDict[inp2_arr_flat] = "A" + funcCall = IR.FuncCall("ClearMemSecret1", argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + + # Unflatten output + assign_out_arr_flat = IR.Assn(out_arr_expr, out_arr_flat_expr) + out_loop = IRUtil.loop( + output_shape, out_iters, [assign_flat_idx_expr, assign_out_arr_flat] + ) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog(out_loop)) + + argsDict = OrderedDict() + argsDict[IR.Int(flat_size, 32)] = "size" + argsDict[out_arr_flat] = "A" + funcCall = IR.FuncCall("ClearMemSecret1", argsDict) + out_prog = IRUtil.prog_merge(out_prog, IRUtil.Prog([funcCall])) + + progExtraBefore = IR.Prog([]) + progExtraAfter = IR.Prog([]) + if Util.Config.disableTruncOpti: + progExtraAfter = self.addTruncateFunctionCall( + node, "ElemWiseMul", out_arr, Util.Config.consSF + ) + else: + inputs_same = expr_1.idf == expr_2.idf + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] + if expr1_sf > self.scaleFac: + progExtraBefore = self.addTruncateFunctionCall( + node.expr1, "ElemWiseMul", expr_1, expr1_sf - self.scaleFac + ) + self.scaleFacMapping[expr_1.idf] = self.scaleFac + if (not inputs_same) and (expr2_sf > self.scaleFac): + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr2, "ElemWiseMul", expr_2, expr2_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr_2.idf] = self.scaleFac + self.scaleFacMapping[out_arr.idf] = 2 * self.scaleFac + + out_prog = IRUtil.prog_merge( + IRUtil.Prog([comment, cmd0]), progExtraBefore, out_prog, progExtraAfter + ) + return (out_prog, out_arr) + + def visitBopMul(self, node: AST.BOp, args=None): + typ_1 = node.expr1.type + typ_2 = node.expr2.type + typ_3 = node.type + if Type.isInt(typ_3): + return self.visitBopMulInt(node) + elif typ_1.dim == 0 or Type.isInt(typ_1) or typ_2.dim == 0 or Type.isInt(typ_2): + return self.visitBopMulScalar1DTensor(node) + else: + return self.visitBopMul2DTensor(node) + + def visitBopMulInt(self, node: AST.BOp, args=None): + (prog_1, expr_1) = self.visit(node.expr1) + (prog_2, expr_2) = self.visit(node.expr2) + + expr_3 = self.getTempVar() + comment = IR.Comment(str(node.metadata)) + bitlen = node.expr.bitlen + decl = IR.Decl(expr_3.idf, node.type, node.type.bitlen, node.type.isSecret) + assign = IR.Assn(expr_3, IRUtil.mul(expr_1, expr_2)) + prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, decl, assign])) + + progExtraBefore = IR.Prog([]) + progExtraAfter = IR.Prog([]) + if Util.Config.disableTruncOpti: + progExtraAfter = self.addTruncateFunctionCall( + node, "MulInt", expr_3, Util.Config.consSF + ) + else: + inputs_same = expr_1.idf == expr_2.idf + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] + if expr1_sf > self.scaleFac: + progExtraBefore = self.addTruncateFunctionCall( + node.expr1, "MulInt", expr_1, expr1_sf - self.scaleFac + ) + self.scaleFacMapping[expr_1.idf] = self.scaleFac + if (not inputs_same) and (expr2_sf > self.scaleFac): + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr2, "MulInt", expr_2, expr2_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr_2.idf] = self.scaleFac + self.scaleFacMapping[expr_3.idf] = 2 * self.scaleFac + + prog_3 = IRUtil.prog_merge(progExtraBefore, prog_3, progExtraAfter) + return (prog_3, expr_3) + + def visitBopMulScalar1DTensor(self, node: AST.BOp, args=None): + (prog_1, expr_1) = self.visit(node.expr1) + (prog_2, expr_2) = self.visit(node.expr2) + + typ_1 = node.expr1.type + typ_2 = node.expr2.type + typ_3 = node.type + + isIntMult = False + if typ_1.dim == 0 or Type.isInt(typ_1): + a, b = expr_1, expr_2 + outputShape = typ_2.shape + isIntMult = Type.isInt(typ_1) + else: + a, b = expr_2, expr_1 + outputShape = typ_1.shape + isIntMult = Type.isInt(typ_2) + + # decl fresh vars + expr_3 = self.getTempVar() + cmd0 = IR.Comment(expr_1.idf + " * " + expr_2.idf) + funcCallArgsDict = OrderedDict() + for ii, curDimSize in enumerate(outputShape): + funcCallArgsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + funcCallArgsDict[a] = "A" + funcCallArgsDict[b] = "B" + funcCallArgsDict[expr_3] = "C" + progExtraBefore = IR.Prog([]) + progExtraAfter = IR.Prog([]) + if Util.Config.disableTruncOpti: + progExtraAfter = self.addTruncateFunctionCall( + node, "ScalarMul", expr_3, Util.Config.consSF + ) + else: + inputs_same = expr_1.idf == expr_2.idf + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] + if expr1_sf > self.scaleFac: + progExtraBefore = self.addTruncateFunctionCall( + node.expr1, "ScalarMul", expr_1, expr1_sf - self.scaleFac + ) + self.scaleFacMapping[expr_1.idf] = self.scaleFac + if (not inputs_same) and (expr2_sf > self.scaleFac): + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr2, "ScalarMul", expr_2, expr2_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr_2.idf] = self.scaleFac + self.scaleFacMapping[expr_3.idf] = 2 * self.scaleFac + + funcCall = IR.FuncCall( + "ScalarMul" + self.varNameDelim + str(len(outputShape)), funcCallArgsDict + ) + prog_3 = IRUtil.prog_merge( + prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]) + ) + prog_3 = IRUtil.prog_merge( + IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter + ) + return (prog_3, expr_3) + + def visitBopMul2DTensor(self, node: AST.BOp, args=None): + (prog_1, expr_1) = self.visit(node.expr1) + (prog_2, expr_2) = self.visit(node.expr2) + + # decl fresh vars + expr_3 = self.getTempVar() + + typ_1 = node.expr1.type + typ_2 = node.expr2.type + typ_3 = node.type + + [I, J] = typ_1.shape + [J, K] = typ_2.shape + typ_mul = Type.Tensor([J]) + + shrT = Util.Config.consSF + + cmd0 = IR.Comment(expr_1.idf + " * " + expr_2.idf) + funcCallArgsDict = OrderedDict() + funcCallArgsDict[IR.Int(I, 32)] = "I" + funcCallArgsDict[IR.Int(J, 32)] = "J" + funcCallArgsDict[IR.Int(K, 32)] = "K" + funcCallArgsDict[expr_1] = "A" + funcCallArgsDict[expr_2] = "B" + funcCallArgsDict[expr_3] = "C" + + # Add an arg as to which arg out of A or B is a model weight + # This is ok, since Athos is right now tailored for neural network inference + # and in inference, in every linear layer, either of A or B will be a model weight. + # This is required because for some backends, knowing which of A or B is a model weight + # can make a difference in their performance. + + assert self.isModel(node.expr1) or self.isModel( + node.expr2 + ), "Expecting one of A or B to be an input by the server (model weight)." + modelIsA = True + if not self.isModel(node.expr1): + modelIsA = False + funcCallArgsDict[IR.Bool(modelIsA)] = "modelIsA" + + progExtraBefore = IR.Prog([]) + progExtraAfter = IR.Prog([]) + if Util.Config.disableTruncOpti: + progExtraAfter = self.addTruncateFunctionCall( + node, "MatMul2D", expr_3, Util.Config.consSF + ) + else: + inputs_same = expr_1.idf == expr_2.idf + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] + if expr1_sf > self.scaleFac: + progExtraBefore = self.addTruncateFunctionCall( + node.expr1, "MatMul2D", expr_1, expr1_sf - self.scaleFac + ) + self.scaleFacMapping[expr_1.idf] = self.scaleFac + if (not inputs_same) and (expr2_sf > self.scaleFac): + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr2, "MatMul2D", expr_2, expr2_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr_2.idf] = self.scaleFac + self.scaleFacMapping[expr_3.idf] = 2 * self.scaleFac + + funcCall = IR.FuncCall("MatMul2D", funcCallArgsDict) + comment = IR.Comment(str(node.metadata)) + prog_3 = IRUtil.prog_merge( + prog_1, prog_2, progExtraBefore, IR.Prog([comment, cmd0, funcCall]) + ) + prog_3 = IRUtil.prog_merge( + IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter + ) + + return (prog_3, expr_3) + + def visitBopConv(self, node: AST.BOp, args=None): + (prog1, expr_1) = self.visit(node.expr1) + (prog2, expr_2) = self.visit(node.expr2) + + convDim = 2 + if AST.PaddingKeysDict.ConvDim in node.options: + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim == 2: + [N, H, W, CI] = node.expr1.type.shape + [FH, FW, CI1, CO] = node.expr2.type.shape + elif convDim == 3: + [N, D, H, W, CI] = node.expr1.type.shape + [FD, FH, FW, CI1, CO] = node.expr2.type.shape + else: + assert False + + returnExpr = self.getTempVar() + comment = IR.Comment( + expr_1.idf + " # " + expr_2.idf + ", convDim = " + str(convDim) + ) + funcCallArgsDict = OrderedDict() + funcCallArgsDict[IR.Int(N, 32)] = "N" + if convDim == 3: + funcCallArgsDict[IR.Int(D, 32)] = "D" + funcCallArgsDict[IR.Int(H, 32)] = "H" + funcCallArgsDict[IR.Int(W, 32)] = "W" + funcCallArgsDict[IR.Int(CI, 32)] = "CI" + if convDim == 3: + funcCallArgsDict[IR.Int(FD, 32)] = "FD" + funcCallArgsDict[IR.Int(FH, 32)] = "FH" + funcCallArgsDict[IR.Int(FW, 32)] = "FW" + funcCallArgsDict[IR.Int(CO, 32)] = "CO" + if convDim == 3: + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadDLeft], 32) + ] = "zPadDLeft" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadDRight], 32) + ] = "zPadDRight" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadHLeft], 32) + ] = "zPadHLeft" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadHRight], 32) + ] = "zPadHRight" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadWLeft], 32) + ] = "zPadWLeft" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.zPadWRight], 32) + ] = "zPadWRight" + if convDim == 3: + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.strideD], 32) + ] = "strideD" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.strideH], 32) + ] = "strideH" + funcCallArgsDict[ + IR.Int(node.options[AST.PaddingKeysDict.strideW], 32) + ] = "strideW" + + isGroupConv = False + if AST.PaddingKeysDict.group in node.options.keys(): + funcCallArgsDict[IR.Int(node.options[AST.PaddingKeysDict.group], 32)] = "G" + isGroupConv = True + + funcCallArgsDict[expr_1] = "input" + funcCallArgsDict[expr_2] = "filter" + if convDim == 3: + funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" + funcCallArgsDict[returnExpr] = "output" + + if convDim == 2: + funcCallName = "Conv2D" + else: + funcCallName = "Conv3D" + + if isGroupConv: + funcCallName += "Group" + + funcCallName += "Wrapper" + + funcCall = IR.FuncCall(funcCallName, funcCallArgsDict) + progConv = IR.Prog([comment, funcCall]) + + progExtraBefore = IR.Prog([]) + progExtraAfter = IR.Prog([]) + if Util.Config.disableTruncOpti: + progExtraAfter = self.addTruncateFunctionCall( + node, "Conv", returnExpr, Util.Config.consSF + ) + else: + inputs_same = expr_1.idf == expr_2.idf + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] + if expr1_sf > self.scaleFac: + progExtraBefore = self.addTruncateFunctionCall( + node.expr1, "Conv", expr_1, expr1_sf - self.scaleFac + ) + self.scaleFacMapping[expr_1.idf] = self.scaleFac + if (not inputs_same) and (expr2_sf > self.scaleFac): + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr2, "Conv", expr_2, expr2_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr_2.idf] = self.scaleFac + self.scaleFacMapping[returnExpr.idf] = 2 * self.scaleFac + + returnProg = IRUtil.prog_merge(prog1, prog2, progExtraBefore, progConv) + returnProg = IRUtil.prog_merge( + IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg, progExtraAfter + ) + return (returnProg, returnExpr) + + def visitBopConvTranspose(self, node: AST.BOp, args=None): + (prog1, expr_1) = self.visit(node.expr1) + (prog2, expr_2) = self.visit(node.expr2) + + convDim = 2 + if AST.PaddingKeysDict.ConvDim in node.options: + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim == 2: + [N, H_prime, W_prime, CI1] = node.expr1.type.shape + [FH, FW, CO, CI] = node.expr2.type.shape + elif convDim == 3: + [N, D_prime, H_prime, W_prime, CI1] = node.expr1.type.shape + [FD, FH, FW, CO, CI] = node.expr2.type.shape + else: + assert False + assert CI1 == CI + + H = node.options[AST.PaddingKeysDict.outputImgH] # outputH + W = node.options[AST.PaddingKeysDict.outputImgW] # outputW + pad_h_total = ( + node.options[AST.PaddingKeysDict.zPadHLeft] + + node.options[AST.PaddingKeysDict.zPadHRight] + ) + pad_w_total = ( + node.options[AST.PaddingKeysDict.zPadWLeft] + + node.options[AST.PaddingKeysDict.zPadWRight] + ) + strideH = node.options[AST.PaddingKeysDict.strideH] + strideW = node.options[AST.PaddingKeysDict.strideW] + [ + pad_h_tr_total, + stride_h_tr, + h_prime_tilde, + ] = AST.Operators.findConvTransposePadding(H, H_prime, FH, pad_h_total, strideH) + [ + pad_w_tr_total, + stride_w_tr, + w_prime_tilde, + ] = AST.Operators.findConvTransposePadding(W, W_prime, FW, pad_w_total, strideW) + + [ + pad_h_tr_left, + pad_h_tr_right, + ] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_h_tr_total) + [ + pad_w_tr_left, + pad_w_tr_right, + ] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_w_tr_total) + + assert ( + AST.Operators.findConvOutputImgSize( + h_prime_tilde, pad_h_tr_total, FH, stride_h_tr + ) + == H + ) + assert ( + AST.Operators.findConvOutputImgSize( + w_prime_tilde, pad_w_tr_total, FW, stride_w_tr + ) + == W + ) + + if convDim == 3: + D = node.options[AST.PaddingKeysDict.outputImgD] # outputD + pad_d_total = ( + node.options[AST.PaddingKeysDict.zPadDLeft] + + node.options[AST.PaddingKeysDict.zPadDRight] + ) + strideD = node.options[AST.PaddingKeysDict.strideD] + [ + pad_d_tr_total, + stride_d_tr, + d_prime_tilde, + ] = AST.Operators.findConvTransposePadding( + D, D_prime, FD, pad_d_total, strideD + ) + [ + pad_d_tr_left, + pad_d_tr_right, + ] = AST.Operators.findLeftRightPaddingFromTotalPadding(pad_d_tr_total) + assert ( + AST.Operators.findConvOutputImgSize( + d_prime_tilde, pad_d_tr_total, FD, stride_d_tr + ) + == D + ) + + returnExpr = self.getTempVar() + comment = IR.Comment( + expr_1.idf + " #T " + expr_2.idf + ", convDim = " + str(convDim) + ) + funcCallArgsDict = OrderedDict() + funcCallArgsDict[IR.Int(N, 32)] = "N" + if convDim == 3: + funcCallArgsDict[IR.Int(D_prime, 32)] = "D_prime" + funcCallArgsDict[IR.Int(H_prime, 32)] = "H_prime" + funcCallArgsDict[IR.Int(W_prime, 32)] = "W_prime" + funcCallArgsDict[IR.Int(CI, 32)] = "CI" + if convDim == 3: + funcCallArgsDict[IR.Int(FD, 32)] = "FD" + funcCallArgsDict[IR.Int(FH, 32)] = "FH" + funcCallArgsDict[IR.Int(FW, 32)] = "FW" + funcCallArgsDict[IR.Int(CO, 32)] = "CO" + if convDim == 3: + funcCallArgsDict[IR.Int(D, 32)] = "D" + funcCallArgsDict[IR.Int(H, 32)] = "H" + funcCallArgsDict[IR.Int(W, 32)] = "W" + if convDim == 3: + funcCallArgsDict[IR.Int(pad_d_tr_left, 32)] = "pad_d_tr_left" + funcCallArgsDict[IR.Int(pad_d_tr_right, 32)] = "pad_d_tr_right" + funcCallArgsDict[IR.Int(pad_h_tr_left, 32)] = "pad_h_tr_left" + funcCallArgsDict[IR.Int(pad_h_tr_right, 32)] = "pad_h_tr_right" + funcCallArgsDict[IR.Int(pad_w_tr_left, 32)] = "pad_w_tr_left" + funcCallArgsDict[IR.Int(pad_w_tr_right, 32)] = "pad_w_tr_right" + if convDim == 3: + funcCallArgsDict[IR.Int(strideD, 32)] = "strideD" + funcCallArgsDict[IR.Int(strideH, 32)] = "strideH" + funcCallArgsDict[IR.Int(strideW, 32)] = "strideW" + + funcCallArgsDict[expr_1] = "input" + funcCallArgsDict[expr_2] = "filter" + if convDim == 3: + funcCallArgsDict[IR.Int(Util.Config.consSF, 32)] = "consSF" + funcCallArgsDict[returnExpr] = "output" + + if convDim == 2: + funcCallName = "ConvTranspose2D" + else: + funcCallName = "ConvTranspose3D" + funcCallName += "Wrapper" + funcCall = IR.FuncCall(funcCallName, funcCallArgsDict) + + progConv = IR.Prog([comment, funcCall]) + + progExtraBefore = IR.Prog([]) + progExtraAfter = IR.Prog([]) + if Util.Config.disableTruncOpti: + progExtraAfter = self.addTruncateFunctionCall( + node, "ConvTranspose", returnExpr, self.scaleFac + ) + else: + inputs_same = expr_1.idf == expr_2.idf + expr1_sf = self.scaleFacMapping[expr_1.idf] + expr2_sf = self.scaleFacMapping[expr_2.idf] + if expr1_sf > self.scaleFac: + progExtraBefore = self.addTruncateFunctionCall( + node.expr1, "ConvTranspose", expr_1, expr1_sf - self.scaleFac + ) + self.scaleFacMapping[expr_1.idf] = self.scaleFac + if (not inputs_same) and (expr2_sf > self.scaleFac): + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr2, "ConvTranspose", expr_2, expr2_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr2.idf] = self.scaleFac + self.scaleFacMapping[returnExpr.idf] = 2 * self.scaleFac + + returnProg = IRUtil.prog_merge(prog1, prog2, progExtraBefore, progConv) + returnProg = IRUtil.prog_merge( + IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg, progExtraAfter + ) + return (returnProg, returnExpr) + + def visitFunc(self, node: AST.Func, args=None): + op = node.op + assert op in [ + AST.Operators.Floor, + AST.Operators.Shape, + AST.Operators.RELU, + AST.Operators.TANH, + AST.Operators.SIGMOID, + AST.Operators.SQRT, + AST.Operators.RSQRT, + AST.Operators.ClearMemSecret, + AST.Operators.ClearMemPublic, + ] + return self.visitFloorLike(node) + + def visitFloorLike(self, node: AST.Func, args=None): + (prog1, expr1) = self.visit(node.expr) + out_expr = self.getTempVar() + + if node.op == AST.Operators.Floor: + funcName = "Floor" + elif node.op == AST.Operators.Shape: + funcName = "Shape" + elif node.op == AST.Operators.RELU: + funcName = "Relu" + elif node.op == AST.Operators.TANH: + funcName = "Tanh" + elif node.op == AST.Operators.SIGMOID: + funcName = "Sigmoid" + elif node.op == AST.Operators.SQRT: + funcName = "Sqrt" + elif node.op == AST.Operators.RSQRT: + funcName = "Sqrt" + elif node.op == AST.Operators.ClearMemSecret: + funcName = "ClearMemSecret" + elif node.op == AST.Operators.ClearMemPublic: + funcName = "ClearMemPublic" + else: + assert False + + # We don't need to clear scalars. + if ( + node.op == AST.Operators.ClearMemSecret + or node.op == AST.Operators.ClearMemPublic + ): + if Type.isInt(node.expr.type): + return (prog1, expr1) + if node.expr.type.dim == 0: + return (prog1, expr1) + + argsList = OrderedDict() + + inputType = node.expr.type + if Type.isTensor(inputType): + for ii, curDim in enumerate(inputType.shape): + argsList[IR.Int(curDim, 32)] = "inShape_" + str(ii) + argsList[expr1] = "inArr" + + if Type.isTensor(node.type): + argsList[out_expr] = "outArr" + + if node.op == AST.Operators.Floor: + argsList[IR.Int(Util.Config.consSF, 32)] = "curScale" + + progExtraBefore = IR.Prog([]) + if Util.Config.disableTruncOpti: + if node.op == AST.Operators.RELU: + argsList[IR.Int(Util.Config.consSF, 32)] = "consSF" + argsList[IR.Bool(False)] = "doTruncation" + if node.op in [ + AST.Operators.TANH, + AST.Operators.SIGMOID, + AST.Operators.SQRT, + AST.Operators.RSQRT, + ]: + argsList[IR.Int(self.scaleFac, 32)] = "sA" + argsList[IR.Int(self.scaleFac, 32)] = "sB" + else: + final_sf = self.scaleFacMapping[expr1.idf] + if node.op == AST.Operators.RELU: + argsList[IR.Int(final_sf - self.scaleFac, 32)] = "consSF" + if final_sf > self.scaleFac: + # If it can't tolerate one more mult operation, then scale down here + assert final_sf - self.scaleFac == self.scaleFac + final_sf = self.scaleFac + argsList[IR.Bool(True)] = "doTruncation" + else: + argsList[IR.Bool(False)] = "doTruncation" + if node.op in [ + AST.Operators.TANH, + AST.Operators.SIGMOID, + AST.Operators.SQRT, + AST.Operators.RSQRT, + ]: + # Since these class of fucntions can only handle input of 32 bitlength, we have to scale down + # inputs before calling them. + if final_sf > 32: + assert ( + final_sf > self.scaleFac + ), "The program scaling factor is invalid. Should be lesser than 32 if network has tan/sig/sqrt" + assert final_sf - self.scaleFac == self.scaleFac + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr, node.op.name, expr1, final_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr1.idf] = self.scaleFac + final_sf = self.scaleFac + argsList[IR.Int(final_sf, 32)] = "sA" + argsList[IR.Int(final_sf, 32)] = "sB" + self.scaleFacMapping[out_expr.idf] = final_sf + + # Tanh/Sigmoid/Sqrt impl only supports upto 32 bitwidth for input + if node.op in [ + AST.Operators.TANH, + AST.Operators.SIGMOID, + AST.Operators.SQRT, + AST.Operators.RSQRT, + ]: + argsList[IR.Int(min(32, self.actualbitwidth), 32)] = "bwA" + argsList[IR.Int(self.actualbitwidth, 32)] = "bwB" + if node.op == AST.Operators.SQRT: + argsList[IR.Bool(False)] = "inverse" + if node.op == AST.Operators.RSQRT: + argsList[IR.Bool(True)] = "inverse" + argsList[IR.Int(8, 32)] = "LUTBITS" + + comment = IR.Comment(str(node.metadata)) + funcNameSuffix = "" + if Type.isTensor(inputType): + funcNameSuffix = str(len(inputType.shape)) + + progFinal = IR.Prog( + [ + comment, + IR.FuncCall(funcName + self.varNameDelim + funcNameSuffix, argsList), + ] + ) + if Type.isTensor(node.type): + progFinal = IRUtil.prog_merge( + IR.Prog([IR.Decl(out_expr.idf, node.type)]), progFinal + ) + + progFinal = IRUtil.prog_merge(prog1, progExtraBefore, progFinal) + return (progFinal, out_expr) + + def visitLet(self, node: AST.Let, args=None): + (prog_1, expr_1) = self.visit(node.decl) + typ_1 = node.decl.type + idf = node.name.name + if not (Type.isInt(typ_1)): + self.name_mapping[idf] = expr_1.idf + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[idf] = self.scaleFacMapping[expr_1.idf] + (prog_2, expr_2) = self.visit(node.expr) + prog_2 = prog_2.subst(idf, expr_1) + expr_2 = expr_2.subst(idf, expr_1) + prog_3 = IRUtil.prog_merge(prog_1, prog_2) + return (prog_3, expr_2) + + def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, args=None): + progList = [] + exprList = [] + for ii, curArg in enumerate(node.argsList): + (progN, exprN) = self.visit(curArg) + progList.append(progN) + exprList.append(exprN) + + returnExpr = self.getTempVar() + + funcName = node.funcName + funcName += self.varNameDelim + str(len(node.outputShape)) + for ii, curArg in enumerate(node.argsList): + if Type.isTensor(curArg.type): + curShape = curArg.type.shape + + # If len(shape) == 0 : that means its a float - no need to qualify + # the function name with 0 in that case, since its essentially + # become an int. + if len(curShape) > 0: + funcName += self.varNameDelim + str(len(curShape)) + ### TODO : right now if random strings like int are passed, its being set as datatype int -- int datatype in + # unintrepreted func call is being used in a hacky way right now + + # Policy : + # First output tensor sizes are inserted in args. + # Then for each input tensor, its shape is inserted in args, followed by the input tensor itself. + # If the current input tensor has the same shape as any of the previous tensors, then its shape is not inserted. + funcArgsList = OrderedDict() + + if not (Util.Config.disableTruncOpti): + # TODO -- remove CreateTensor from uninterp function calls + for ii, curArg in enumerate(node.argsList): + curExpr = exprList[ii] + curScale = self.scaleFacMapping[curExpr.idf] + curType = curArg.type + if ( + (not (Type.isInt(curType))) + and (curScale > self.scaleFac) + and (curType.isSecret) + ): + curProg = self.addTruncateFunctionCall( + curArg, "UninterpFuncCall", curExpr, curScale - self.scaleFac + ) + progList.insert(0, curProg) + self.scaleFacMapping[curExpr.idf] = self.scaleFac + + self.scaleFacMapping[returnExpr.idf] = self.scaleFac + + tensorShapesFound = {} + outputShape = node.type.shape + for ii, curDim in enumerate(outputShape): + funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii) + tensorShapesFound[tuple(outputShape)] = True + for ii in range(0, len(node.argsList)): + if node.outputDiffInpDims < 2 and Type.isTensor(node.argsList[ii].type): + curInpShape = node.argsList[ii].type.shape + if (node.outputDiffInpDims == 1) or ( + node.outputDiffInpDims == 0 + and tuple(curInpShape) not in tensorShapesFound + ): + for jj, curDim in enumerate(curInpShape): + funcArgsList[IR.Int(curDim, 32)] = ( + "Input_" + str(ii) + self.varNameDelim + str(jj) + ) + tensorShapesFound[tuple(curInpShape)] = True + funcArgsList[exprList[ii]] = "inpExpr_" + str(ii) + funcArgsList[returnExpr] = "output" + + comment = IR.Comment(str(node.metadata)) + progFinal = progList[0] + if len(progList) > 1: + for ii in range(1, len(progList)): + progFinal = IRUtil.prog_merge(progFinal, progList[ii]) + progFinal = IRUtil.prog_merge( + progFinal, IR.Prog([comment, IR.FuncCall(funcName, funcArgsList)]) + ) + + progFinal = IRUtil.prog_merge( + IR.Prog( + [ + IR.Decl( + returnExpr.idf, + node.type, + isSecret=False if node.isSecret is False else "secret", + ) + ] + ), + progFinal, + ) + return (progFinal, returnExpr) + + def visitArgMax(self, node: AST.ArgMax, args=None): + (prog_1, expr1) = self.visit(node.expr) + (prog_2, expr2) = self.visit(node.dim) + + tmpExpr = self.getTempVar() + outputShape = node.type.shape + + funcArgsList = OrderedDict() + outputShape = node.type.shape + for ii, curDim in enumerate(outputShape): + funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii) + for ii, curDim in enumerate(node.inShape): + funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii) + funcArgsList[expr1] = "inArr" + funcArgsList[expr2] = "dim" + funcArgsList[tmpExpr] = "outArr" + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[tmpExpr.idf] = -1 + + funcCall = IR.FuncCall( + "ArgMax" + self.varNameDelim + str(len(outputShape)), funcArgsList + ) + comment = IR.Comment(str(node.metadata)) + prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall])) + prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), prog_3) + return (prog_3, tmpExpr) + + def visitInput(self, node: AST.Input, args=None): + returnExpr = self.getTempVar() + returnExpr.inputVar = True + comment = IR.Comment(str(node.metadata)) + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[returnExpr.idf] = self.scaleFac + return ( + IR.Prog( + [ + comment, + IR.Input( + returnExpr, + node.shape, + node.dataType, + node.isSecret, + node.inputByParty, + ), + ] + ), + returnExpr, + ) + + def visitReduce(self, node: AST.Reduce, args=None): + (prog_1, expr1) = self.visit(node.expr) + assert node.op in [AST.Operators.ADD, AST.Operators.Mean] + + # We already have the output shape so we dont need to calculate with keep_dims + + """ We need to reduce across axes. Example: Say reduction axes are specified as 0,3 and keep dim = false output rank -> len(input_shape) - len(reduction_axes) @@ -1269,190 +1659,239 @@ def visitReduce(self, node:AST.Reduce, args=None): for i2=[0:s2] output[i1][i2] = out_flat[i1*s2 + i2] - ''' - reduced_dims = node.reductionAxesList - inputShape = node.expr.type.shape - perm = [] - calculated_shape = [] - inputiters = self.getTempIterators(node.expr.type.dim) - outputiters = [] - no_elems = 1 - j = 0 - - for i in range(len(inputShape)): - if i not in reduced_dims: - perm.append(i) - # perm will now be [ 1 ,2 ] + [ 0, 3] - perm.extend(reduced_dims) - print(perm) - print(reduced_dims) - loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] - shuffled_inputiters = [inputiters[perm[i]] for i in range(len(inputShape))] - - for i in range(len(inputShape)): - if i not in reduced_dims: - calculated_shape.append(inputShape[i]) - outputiters.append(inputiters[j]) - j = j + 1 - else: - no_elems = no_elems * inputShape[i] - if node.keepdims == 1: - calculated_shape.append(1) - outputiters.append(IR.Int(0,32)) - - if calculated_shape == []: - calculated_shape = [1] - outputiters.append(IR.Int(0,32)) - - outputShape = node.type.shape - assert calculated_shape == outputShape, "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape) - - sumExpr = self.getTempVar() - sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int()) - initSumCmd = IR.Assn(sumExpr, IRUtil.zero) - updateSumCmd = IR.Assn(sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, shuffled_inputiters))) - - if node.op == AST.Operators.Mean: - outer_nesting = len(inputShape) - len(reduced_dims) - temp_flat = self.getTempVar() - temp_flat_decl = IR.Decl(temp_flat.idf, - Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), - isSecret=node.type.isSecret) - # i1*s2 + i2 - flat_idx_expr = IRUtil.getFlatArrIdxExpr(inputiters[:outer_nesting], loop_shape[:outer_nesting]) - # temp_flat[i1*s2 + i2] = sum - temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) - updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) - elif node.op == AST.Operators.ADD: - output = self.getTempVar() - output_decl = IR.Decl(output.idf, node.type) - out_expr = IRUtil.addIndex(output, outputiters) - updateOutCmd = IR.Assn(out_expr, sumExpr) - - # Generate the sum loop - inner_loops_processed = 0 - sum_loop = [updateSumCmd] - for i in reversed(range(len(loop_shape))): - sum_loop = [IR.For(inputiters[i], 0, sum_loop, 0, endInt=loop_shape[i])] - inner_loops_processed+=1 - if(inner_loops_processed == len(reduced_dims)): - sum_loop = [initSumCmd] + sum_loop + [updateOutCmd] - - if node.op == AST.Operators.ADD: - comment = IR.Comment(str(node.metadata)) - final_prog = IRUtil.prog_merge( prog_1, - IR.Prog([comment]), - IR.Prog([sumExpr_decl, output_decl]), - IR.Prog(sum_loop)) - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] - - return (final_prog, output) - - # Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat) - out_flat = self.getTempVar() - out_flat_decl = IR.Decl(out_flat.idf, - Type.Tensor([Util.get_volume(loop_shape[:outer_nesting])], node.type.bitlen, node.type.isSecret, node.type.taint), - isSecret=node.type.isSecret) - argsDict = OrderedDict() - argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" - argsDict[temp_flat] = "input" - argsDict[IR.Int(Util.get_volume(loop_shape[outer_nesting:]), 32)] = "divisor" - argsDict[out_flat] = "output" - div_call = IR.FuncCall("ElemWiseVectorPublicDiv", argsDict) - - # Free temp_flat here - # Clear temp arrays - argsDict = OrderedDict() - argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" - argsDict[temp_flat] = "A" - free_temp_flat_call = IR.FuncCall("ClearMemSecret1", argsDict) - - # Unflatten the output - output = self.getTempVar() - output_decl = IR.Decl(output.idf, node.type) - out_expr = IRUtil.addIndex(output, outputiters) - out_flat_expr = IRUtil.addIndex(out_flat, [flat_idx_expr]) - out_assn_expr = IR.Assn(out_expr, out_flat_expr) - unflatten_loop = IRUtil.loop(loop_shape[:outer_nesting], inputiters[:outer_nesting], [out_assn_expr]) - - # Free out_flat here - argsDict = OrderedDict() - argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" - argsDict[out_flat] = "A" - free_out_flat_call = IR.FuncCall("ClearMemSecret1", argsDict) - - if not(Util.Config.disableTruncOpti): - self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] - - comment = IR.Comment(str(node.metadata)) - final_prog = IRUtil.prog_merge( prog_1, - IR.Prog([comment]), - IR.Prog([sumExpr_decl, temp_flat_decl, out_flat_decl, output_decl]), - IR.Prog(sum_loop), - IR.Prog([div_call]), - IR.Prog([free_temp_flat_call]), - IR.Prog(unflatten_loop), - IR.Prog([free_out_flat_call])) - - return (final_prog, output) - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): - (prog1, expr1) = self.visit(node.expr) - (prog2, expr2) = self.visit(node.multExpr) - (prog3, expr3) = self.visit(node.addExpr) - - returnExpr = self.getTempVar() - - funcArgsList = OrderedDict() - for ii, elem in enumerate(node.type.shape): - funcArgsList[IR.Int(elem, 32)] = "elem_"+str(ii) - funcArgsList[expr1] = "expr" - funcArgsList[expr2] = "multExpr" - funcArgsList[expr3] = "addExpr" - - progExtraBefore = IR.Prog([]) - multExprScaleDownSf = self.scaleFac - addExprScaleUpSf = 0 - if not(Util.Config.disableTruncOpti): - #TruncOpti is on - multExprScaleDownSf = 0 - addExprScaleUpSf = 0 - - expr_sf = self.scaleFacMapping[expr1.idf] - multExpr_sf = self.scaleFacMapping[expr2.idf] - addExpr_sf = self.scaleFacMapping[expr3.idf] - if (expr_sf > self.scaleFac): - #Scale down needed - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac)) - self.scaleFacMapping[expr1.idf] = self.scaleFac - - if (multExpr_sf > self.scaleFac): - #Scale down needed - progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.multExpr, "FusedBatchNorm", expr2, multExpr_sf - self.scaleFac)) - self.scaleFacMapping[expr2.idf] = self.scaleFac - - final_sf = 2*self.scaleFac - assert(final_sf >= addExpr_sf) - if (final_sf > addExpr_sf): - addExprScaleUpSf = final_sf - addExpr_sf - self.scaleFacMapping[expr3.idf] += addExprScaleUpSf - - self.scaleFacMapping[returnExpr.idf] = final_sf - - funcArgsList[IR.Int(multExprScaleDownSf, 32)] = "multExprScaleDownSf" - funcArgsList[IR.Int(addExprScaleUpSf, 32)] = "addExprScaleUpSf" - funcArgsList[returnExpr] = "returnExpr" - - funcCallIR = IR.FuncCall("FusedBatchNorm" + self.varNameDelim - + str(len(node.type.shape)) + self.varNameDelim #one for output - + str(len(node.type.shape)) + self.varNameDelim #one for input - + str(len(node.multExpr.type.shape)) + self.varNameDelim - + str(len(node.addExpr.type.shape)), - funcArgsList) - - comment = IR.Comment(str(node.metadata)) - returnProg = IRUtil.prog_merge(prog1, prog2, prog3, progExtraBefore, IR.Prog([comment, funcCallIR])) - - returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg) - return (returnProg, returnExpr) + """ + reduced_dims = node.reductionAxesList + inputShape = node.expr.type.shape + perm = [] + calculated_shape = [] + inputiters = self.getTempIterators(node.expr.type.dim) + outputiters = [] + no_elems = 1 + j = 0 + + for i in range(len(inputShape)): + if i not in reduced_dims: + perm.append(i) + # perm will now be [ 1 ,2 ] + [ 0, 3] + perm.extend(reduced_dims) + print(perm) + print(reduced_dims) + loop_shape = [inputShape[perm[i]] for i in range(len(inputShape))] + shuffled_inputiters = [inputiters[perm[i]] for i in range(len(inputShape))] + + for i in range(len(inputShape)): + if i not in reduced_dims: + calculated_shape.append(inputShape[i]) + outputiters.append(inputiters[j]) + j = j + 1 + else: + no_elems = no_elems * inputShape[i] + if node.keepdims == 1: + calculated_shape.append(1) + outputiters.append(IR.Int(0, 32)) + + if calculated_shape == []: + calculated_shape = [1] + outputiters.append(IR.Int(0, 32)) + + outputShape = node.type.shape + assert ( + calculated_shape == outputShape + ), "calculate shape:{} - real_shape: {}".format(calculated_shape, outputShape) + + sumExpr = self.getTempVar() + sumExpr_decl = IR.Decl(sumExpr.idf, Type.Int()) + initSumCmd = IR.Assn(sumExpr, IRUtil.zero) + updateSumCmd = IR.Assn( + sumExpr, IRUtil.add(sumExpr, IRUtil.addIndex(expr1, shuffled_inputiters)) + ) + + if node.op == AST.Operators.Mean: + outer_nesting = len(inputShape) - len(reduced_dims) + temp_flat = self.getTempVar() + temp_flat_decl = IR.Decl( + temp_flat.idf, + Type.Tensor( + [Util.get_volume(loop_shape[:outer_nesting])], + node.type.bitlen, + node.type.isSecret, + node.type.taint, + ), + isSecret=node.type.isSecret, + ) + # i1*s2 + i2 + flat_idx_expr = IRUtil.getFlatArrIdxExpr( + inputiters[:outer_nesting], loop_shape[:outer_nesting] + ) + # temp_flat[i1*s2 + i2] = sum + temp_flat_expr = IRUtil.addIndex(temp_flat, [flat_idx_expr]) + updateOutCmd = IR.Assn(temp_flat_expr, sumExpr) + elif node.op == AST.Operators.ADD: + output = self.getTempVar() + output_decl = IR.Decl(output.idf, node.type) + out_expr = IRUtil.addIndex(output, outputiters) + updateOutCmd = IR.Assn(out_expr, sumExpr) + + # Generate the sum loop + inner_loops_processed = 0 + sum_loop = [updateSumCmd] + for i in reversed(range(len(loop_shape))): + sum_loop = [IR.For(inputiters[i], 0, sum_loop, 0, endInt=loop_shape[i])] + inner_loops_processed += 1 + if inner_loops_processed == len(reduced_dims): + sum_loop = [initSumCmd] + sum_loop + [updateOutCmd] + + if node.op == AST.Operators.ADD: + comment = IR.Comment(str(node.metadata)) + final_prog = IRUtil.prog_merge( + prog_1, + IR.Prog([comment]), + IR.Prog([sumExpr_decl, output_decl]), + IR.Prog(sum_loop), + ) + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] + + return (final_prog, output) + + # Insert call to ElemWiseVectorPublicDiv(size=s1*s2, inp=temp_flat, divisor=s0*s3, out=out_flat) + out_flat = self.getTempVar() + out_flat_decl = IR.Decl( + out_flat.idf, + Type.Tensor( + [Util.get_volume(loop_shape[:outer_nesting])], + node.type.bitlen, + node.type.isSecret, + node.type.taint, + ), + isSecret=node.type.isSecret, + ) + argsDict = OrderedDict() + argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" + argsDict[temp_flat] = "input" + argsDict[IR.Int(Util.get_volume(loop_shape[outer_nesting:]), 32)] = "divisor" + argsDict[out_flat] = "output" + div_call = IR.FuncCall("ElemWiseVectorPublicDiv", argsDict) + + # Free temp_flat here + # Clear temp arrays + argsDict = OrderedDict() + argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" + argsDict[temp_flat] = "A" + free_temp_flat_call = IR.FuncCall("ClearMemSecret1", argsDict) + + # Unflatten the output + output = self.getTempVar() + output_decl = IR.Decl(output.idf, node.type) + out_expr = IRUtil.addIndex(output, outputiters) + out_flat_expr = IRUtil.addIndex(out_flat, [flat_idx_expr]) + out_assn_expr = IR.Assn(out_expr, out_flat_expr) + unflatten_loop = IRUtil.loop( + loop_shape[:outer_nesting], inputiters[:outer_nesting], [out_assn_expr] + ) + + # Free out_flat here + argsDict = OrderedDict() + argsDict[IR.Int(Util.get_volume(loop_shape[:outer_nesting]), 32)] = "size" + argsDict[out_flat] = "A" + free_out_flat_call = IR.FuncCall("ClearMemSecret1", argsDict) + + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[output.idf] = self.scaleFacMapping[expr1.idf] + + comment = IR.Comment(str(node.metadata)) + final_prog = IRUtil.prog_merge( + prog_1, + IR.Prog([comment]), + IR.Prog([sumExpr_decl, temp_flat_decl, out_flat_decl, output_decl]), + IR.Prog(sum_loop), + IR.Prog([div_call]), + IR.Prog([free_temp_flat_call]), + IR.Prog(unflatten_loop), + IR.Prog([free_out_flat_call]), + ) + + return (final_prog, output) + + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None): + (prog1, expr1) = self.visit(node.expr) + (prog2, expr2) = self.visit(node.multExpr) + (prog3, expr3) = self.visit(node.addExpr) + + returnExpr = self.getTempVar() + + funcArgsList = OrderedDict() + for ii, elem in enumerate(node.type.shape): + funcArgsList[IR.Int(elem, 32)] = "elem_" + str(ii) + funcArgsList[expr1] = "expr" + funcArgsList[expr2] = "multExpr" + funcArgsList[expr3] = "addExpr" + + progExtraBefore = IR.Prog([]) + multExprScaleDownSf = self.scaleFac + addExprScaleUpSf = 0 + if not (Util.Config.disableTruncOpti): + # TruncOpti is on + multExprScaleDownSf = 0 + addExprScaleUpSf = 0 + + expr_sf = self.scaleFacMapping[expr1.idf] + multExpr_sf = self.scaleFacMapping[expr2.idf] + addExpr_sf = self.scaleFacMapping[expr3.idf] + if expr_sf > self.scaleFac: + # Scale down needed + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac + ), + ) + self.scaleFacMapping[expr1.idf] = self.scaleFac + + if multExpr_sf > self.scaleFac: + # Scale down needed + progExtraBefore = IRUtil.prog_merge( + progExtraBefore, + self.addTruncateFunctionCall( + node.multExpr, + "FusedBatchNorm", + expr2, + multExpr_sf - self.scaleFac, + ), + ) + self.scaleFacMapping[expr2.idf] = self.scaleFac + + final_sf = 2 * self.scaleFac + assert final_sf >= addExpr_sf + if final_sf > addExpr_sf: + addExprScaleUpSf = final_sf - addExpr_sf + self.scaleFacMapping[expr3.idf] += addExprScaleUpSf + + self.scaleFacMapping[returnExpr.idf] = final_sf + + funcArgsList[IR.Int(multExprScaleDownSf, 32)] = "multExprScaleDownSf" + funcArgsList[IR.Int(addExprScaleUpSf, 32)] = "addExprScaleUpSf" + funcArgsList[returnExpr] = "returnExpr" + + funcCallIR = IR.FuncCall( + "FusedBatchNorm" + + self.varNameDelim + + str(len(node.type.shape)) + + self.varNameDelim # one for output + + str(len(node.type.shape)) + + self.varNameDelim # one for input + + str(len(node.multExpr.type.shape)) + + self.varNameDelim + + str(len(node.addExpr.type.shape)), + funcArgsList, + ) + + comment = IR.Comment(str(node.metadata)) + returnProg = IRUtil.prog_merge( + prog1, prog2, prog3, progExtraBefore, IR.Prog([comment, funcCallIR]) + ) + + returnProg = IRUtil.prog_merge( + IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg + ) + return (returnProg, returnExpr) diff --git a/Athos/SeeDot/IR/IRUtil.py b/Athos/SeeDot/IR/IRUtil.py index 1c590a71..47c71be0 100644 --- a/Athos/SeeDot/IR/IRUtil.py +++ b/Athos/SeeDot/IR/IRUtil.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,175 +20,275 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy as np from IR.IR import * from Util import * + def init(): - global zero, one, negone, negmax + global zero, one, negone, negmax + + zero = Int(0) + one = Int(1) + negone = Int(-1) + negmax = Int.negMax() + + +def add(e1: Expr, e2: Expr) -> Expr: + return IntBop(e1, Op.Op["+"], e2) + + +def sub(e1: Expr, e2: Expr) -> Expr: + return IntBop(e1, Op.Op["-"], e2) + + +def mul(e1: Expr, e2: Expr) -> Expr: + return IntBop(e1, Op.Op["*"], e2) + + +def div(e1: Expr, e2: Expr) -> Expr: + return IntBop(e1, Op.Op["/"], e2) + + +def inc(e: Expr) -> Expr: + return add(e, one) + + +def dec(e: Expr) -> Expr: + return sub(e, one) + + +def andd(e1: Expr, e2: Expr) -> Expr: + return BoolBop(e1, Op.Op["&&"], e2) + + +def orr(e1: Expr, e2: Expr) -> Expr: + return BoolBop(e1, Op.Op["||"], e2) + + +def eq(e1: Expr, e2: Expr) -> Expr: + return BoolCop(e1, Op.Op["=="], e2) + + +def neq(e1: Expr, e2: Expr) -> Expr: + return BoolCop(e1, Op.Op["!="], e2) + + +def lt(e1: Expr, e2: Expr) -> Expr: + return BoolCop(e1, Op.Op["<"], e2) + + +def lte(e1: Expr, e2: Expr) -> Expr: + return BoolCop(e1, Op.Op["<="], e2) + - zero = Int(0) - one = Int(1) - negone = Int(-1) - negmax = Int.negMax() +def gt(e1: Expr, e2: Expr) -> Expr: + return BoolCop(e1, Op.Op[">"], e2) -def add(e1:Expr, e2:Expr) -> Expr: return IntBop(e1, Op.Op['+'], e2) -def sub(e1:Expr, e2:Expr) -> Expr: return IntBop(e1, Op.Op['-'], e2) -def mul(e1:Expr, e2:Expr) -> Expr: return IntBop(e1, Op.Op['*'], e2) -def div(e1:Expr, e2:Expr) -> Expr: return IntBop(e1, Op.Op['/'], e2) -def inc(e:Expr) -> Expr: return add(e, one) -def dec(e:Expr) -> Expr: return sub(e, one) +def gte(e1: Expr, e2: Expr) -> Expr: + return BoolCop(e1, Op.Op[">="], e2) -def andd(e1:Expr, e2:Expr) -> Expr: return BoolBop(e1, Op.Op['&&'], e2) -def orr(e1:Expr, e2:Expr) -> Expr: return BoolBop(e1, Op.Op['||'], e2) -def eq(e1:Expr, e2:Expr) -> Expr: return BoolCop(e1, Op.Op['=='], e2) -def neq(e1:Expr, e2:Expr) -> Expr: return BoolCop(e1, Op.Op['!='], e2) -def lt(e1:Expr, e2:Expr) -> Expr: return BoolCop(e1, Op.Op['<'], e2) -def lte(e1:Expr, e2:Expr) -> Expr: return BoolCop(e1, Op.Op['<='], e2) -def gt(e1:Expr, e2:Expr) -> Expr: return BoolCop(e1, Op.Op['>'], e2) -def gte(e1:Expr, e2:Expr) -> Expr: return BoolCop(e1, Op.Op['>='], e2) +def bitAnd(e1: Expr, e2: Expr) -> Expr: + return IntBop(e1, Op.Op["&"], e2) -def bitAnd(e1:Expr, e2:Expr) -> Expr: return IntBop(e1, Op.Op['&'], e2) -def max(e1:Expr, e2:Expr) -> Expr: - return CExpr(BoolCop(e1, Op.Op['>'], e2), e1, e2) +def max(e1: Expr, e2: Expr) -> Expr: + return CExpr(BoolCop(e1, Op.Op[">"], e2), e1, e2) -def max_uint(e1:Expr, e2:Expr) -> Expr: - return CExpr(BoolCop(e1, Op.Op['>'], e2), e1, e2) -def max_sint(e1:Expr, e2:Expr) -> Expr: - return cond_zero(e1, cond_zero(e2, max_uint(e1, e2), e1), cond_zero(e2, e2, max_uint(e1, e2))) +def max_uint(e1: Expr, e2: Expr) -> Expr: + return CExpr(BoolCop(e1, Op.Op[">"], e2), e1, e2) -def negate(e:Expr) -> Expr: - return IntUop(Op.Op['-'], e) -def shl(e:Expr, n:int) -> Expr: - assert(n >= 0) - if n == 0: return e - return IntBop(e, Op.Op['<<'], Int(n)) +def max_sint(e1: Expr, e2: Expr) -> Expr: + return cond_zero( + e1, cond_zero(e2, max_uint(e1, e2), e1), cond_zero(e2, e2, max_uint(e1, e2)) + ) -def shrUint(e:Expr, n:int) -> Expr: - assert(n >= 0) - if(n == 0): return e - return IntBop(e, Op.Op['>>'], Int(n)) -def shr(e:Expr, n:int) -> Expr: - return shrDefault(e, n) +def negate(e: Expr) -> Expr: + return IntUop(Op.Op["-"], e) -def shrDefault(e:Expr, n:int) -> Expr: - assert(n >= 0) - if(n == 0): return e - return cond_zero(e, IntBop(e, Op.Op['>>'], Int(n)), IntBop(IntBop(IntBop(e, Op.Op['^'], negone), Op.Op['>>'], Int(n)), Op.Op['^'], negone)) -def shrVar(e:Expr, n:Var) -> Expr: - return cond_zero(e, IntBop(e, Op.Op['>>'], n), IntBop(IntBop(IntBop(e, Op.Op['^'], negone), Op.Op['>>'], n), Op.Op['^'], negone)) +def shl(e: Expr, n: int) -> Expr: + assert n >= 0 + if n == 0: + return e + return IntBop(e, Op.Op["<<"], Int(n)) -def castToInt(e:Expr): - return TypeCast(DataType.getIntStr(), e) -def castToFloat(e:Expr): - return TypeCast(DataType.getFloatStr(), e) +def shrUint(e: Expr, n: int) -> Expr: + assert n >= 0 + if n == 0: + return e + return IntBop(e, Op.Op[">>"], Int(n)) -def addIndex(var:Var, indices:list, prefix:bool=False) -> Var: - if prefix == False: - return Var(var.idf, var.idx + indices, var.inputVar) - else: - return Var(var.idf, indices + var.idx, var.inputVar) -def cond_zero(e:Expr, et:Expr, ef:Expr) -> Expr: - return CExpr(BoolCop(e, Op.Op['>'], zero), et, ef) +def shr(e: Expr, n: int) -> Expr: + return shrDefault(e, n) -def relu(e:Expr): return cond_zero(e, e, zero) -def loop_shr(lhs:Expr, rhs:Expr, shape:list, iters:list, n:int) -> CmdList: - lhs_elt = addIndex(lhs, iters) - rhs_elt = addIndex(rhs, iters) - return loop(shape, iters, [Assn(lhs_elt, shr(rhs_elt,n))]) +def shrDefault(e: Expr, n: int) -> Expr: + assert n >= 0 + if n == 0: + return e + return cond_zero( + e, + IntBop(e, Op.Op[">>"], Int(n)), + IntBop( + IntBop(IntBop(e, Op.Op["^"], negone), Op.Op[">>"], Int(n)), + Op.Op["^"], + negone, + ), + ) -def initVarToZero(e:Expr) -> Cmd: return Assn(e, Int(0)) -def incCmd(e:Var) -> Cmd: return Assn(e, inc(e)) -def decCmd(e:Var) -> Cmd: return Assn(e, dec(e)) +def shrVar(e: Expr, n: Var) -> Expr: + return cond_zero( + e, + IntBop(e, Op.Op[">>"], n), + IntBop( + IntBop(IntBop(e, Op.Op["^"], negone), Op.Op[">>"], n), Op.Op["^"], negone + ), + ) + + +def castToInt(e: Expr): + return TypeCast(DataType.getIntStr(), e) + + +def castToFloat(e: Expr): + return TypeCast(DataType.getFloatStr(), e) + + +def addIndex(var: Var, indices: list, prefix: bool = False) -> Var: + if prefix == False: + return Var(var.idf, var.idx + indices, var.inputVar) + else: + return Var(var.idf, indices + var.idx, var.inputVar) + + +def cond_zero(e: Expr, et: Expr, ef: Expr) -> Expr: + return CExpr(BoolCop(e, Op.Op[">"], zero), et, ef) + + +def relu(e: Expr): + return cond_zero(e, e, zero) + + +def loop_shr(lhs: Expr, rhs: Expr, shape: list, iters: list, n: int) -> CmdList: + lhs_elt = addIndex(lhs, iters) + rhs_elt = addIndex(rhs, iters) + return loop(shape, iters, [Assn(lhs_elt, shr(rhs_elt, n))]) + + +def initVarToZero(e: Expr) -> Cmd: + return Assn(e, Int(0)) + + +def incCmd(e: Var) -> Cmd: + return Assn(e, inc(e)) + + +def decCmd(e: Var) -> Cmd: + return Assn(e, dec(e)) + def prog_merge(*prog_l, resource=0): - cmd_l = flatten([prog.cmd_l for prog in prog_l]) - Res = 0 - for x in prog_l: - Res = Res + x.resource - return Prog(cmd_l, resource=Res) + cmd_l = flatten([prog.cmd_l for prog in prog_l]) + Res = 0 + for x in prog_l: + Res = Res + x.resource + return Prog(cmd_l, resource=Res) + # multiplexer -def add_idx_priv(var:Var, e:Expr, n:int, offset:int=0) -> Expr: - assert(n >= 1) - mask = 1 << (n - 1) - - # use if-else - if False: - # for n=3: - # if e & 100 == 0: - # if e & 010 == 0: - # if e & 001 == 0: var[000] - # else: var[001] - # else: - # if e & 001 == 0: var[010] - # else: var[011] - # else: ... - expr_cmp = eq(IntBop(e, Op.Op['&'], Int(mask)), zero) - if n == 1: - return CExpr(expr_cmp, - addIndex(var, [Int(offset + 0)]), - addIndex(var, [Int(offset + mask)])) - else: - return CExpr(expr_cmp, - add_idx_priv(var, e, n - 1, offset + 0), - add_idx_priv(var, e, n - 1, offset + mask)) - # use *, + - else: - # for n=2: - # (1-(e&10)>>1) * ((1-(e&01)>>0)*var[00] + ((e&01)>>0)*var[01]) + - # ( (e&10)>>1) * ((1-(e&01)>>0)*var[10] + ((e&01)>>0)*var[11]) - expr_1 = shrUint(IntBop(e, Op.Op['&'], Int(mask)), n - 1) - expr_0 = sub(one, expr_1) - if n == 1: - return add(mul(expr_0, addIndex(var, [Int(offset + 0)])), - mul(expr_1, addIndex(var, [Int(offset + mask)]))) - else: - return add(mul(expr_0, add_idx_priv(var, e, n - 1, offset + 0)), - mul(expr_1, add_idx_priv(var, e, n - 1, offset + mask))) +def add_idx_priv(var: Var, e: Expr, n: int, offset: int = 0) -> Expr: + assert n >= 1 + mask = 1 << (n - 1) + + # use if-else + if False: + # for n=3: + # if e & 100 == 0: + # if e & 010 == 0: + # if e & 001 == 0: var[000] + # else: var[001] + # else: + # if e & 001 == 0: var[010] + # else: var[011] + # else: ... + expr_cmp = eq(IntBop(e, Op.Op["&"], Int(mask)), zero) + if n == 1: + return CExpr( + expr_cmp, + addIndex(var, [Int(offset + 0)]), + addIndex(var, [Int(offset + mask)]), + ) + else: + return CExpr( + expr_cmp, + add_idx_priv(var, e, n - 1, offset + 0), + add_idx_priv(var, e, n - 1, offset + mask), + ) + # use *, + + else: + # for n=2: + # (1-(e&10)>>1) * ((1-(e&01)>>0)*var[00] + ((e&01)>>0)*var[01]) + + # ( (e&10)>>1) * ((1-(e&01)>>0)*var[10] + ((e&01)>>0)*var[11]) + expr_1 = shrUint(IntBop(e, Op.Op["&"], Int(mask)), n - 1) + expr_0 = sub(one, expr_1) + if n == 1: + return add( + mul(expr_0, addIndex(var, [Int(offset + 0)])), + mul(expr_1, addIndex(var, [Int(offset + mask)])), + ) + else: + return add( + mul(expr_0, add_idx_priv(var, e, n - 1, offset + 0)), + mul(expr_1, add_idx_priv(var, e, n - 1, offset + mask)), + ) + # iteration -def loop(shape:list, iters:list, cmdl_body:CmdList, factor=0) -> CmdList: - cmdl_for = cmdl_body - for i in reversed(range(len(shape))): - cmdl_for = [For(iters[i], 0, cmdl_for, factor, endInt=shape[i])] - return cmdl_for - -def print_loop(shape:list, iters:list, cmdl_body:CmdList, factor=0) -> CmdList: - cmdl_for = cmdl_body - for i in reversed(range(len(shape))): - cmdl_for = [For(iters[i], 0, lt(iters[i], Int(shape[i])), cmdl_for, factor), Print(Var('""'))] - return cmdl_for +def loop(shape: list, iters: list, cmdl_body: CmdList, factor=0) -> CmdList: + cmdl_for = cmdl_body + for i in reversed(range(len(shape))): + cmdl_for = [For(iters[i], 0, cmdl_for, factor, endInt=shape[i])] + return cmdl_for + + +def print_loop(shape: list, iters: list, cmdl_body: CmdList, factor=0) -> CmdList: + cmdl_for = cmdl_body + for i in reversed(range(len(shape))): + cmdl_for = [ + For(iters[i], 0, lt(iters[i], Int(shape[i])), cmdl_for, factor), + Print(Var('""')), + ] + return cmdl_for + # For tensor A of shape = 7 x 1 x 5 # And out_iters = [i0, i1, i2, i3] # Broadcast mask = [True, False, True, False] # We generate iters = A[i1][0][i3] # If input is scalar, broadcast_mask=[] and inp_shape=[] -def getMaskedIters(broadcast_mask: list, out_iters: list, inp_shape : list): - base_idx = len(out_iters) - len(inp_shape) - masked_iters = [] - for i in range(len(broadcast_mask)): - if broadcast_mask[i]: - masked_iters.append(Int(0,32)) - else: - masked_iters.append(out_iters[base_idx]) - base_idx +=1 - return masked_iters +def getMaskedIters(broadcast_mask: list, out_iters: list, inp_shape: list): + base_idx = len(out_iters) - len(inp_shape) + masked_iters = [] + for i in range(len(broadcast_mask)): + if broadcast_mask[i]: + masked_iters.append(Int(0, 32)) + else: + masked_iters.append(out_iters[base_idx]) + base_idx += 1 + return masked_iters + # Given input # A (4d array): 8 x 1 x 6 x 1 @@ -200,20 +300,25 @@ def getMaskedIters(broadcast_mask: list, out_iters: list, inp_shape : list): # for i2=[0:6] # for i3=[0:8] # Result[i0][i1][i2][i3] = A[i0][0][i2][0] + B[i1][0][i3] -def generateBroadcastLoopBOp(expr_1, inp1_shape: list, expr_2, inp2_shape : list, expr_out, op: Op.Op): - output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes(inp1_shape, inp2_shape) - out_iters = [Var('i' + str(i)) for i in range(len(output_shape))] - inp1_iters = getMaskedIters(broadcast_mask_1, out_iters, inp1_shape) - inp2_iters = getMaskedIters(broadcast_mask_2, out_iters, inp2_shape) - - inp1_arr_expr = addIndex(expr_1, inp1_iters) - inp2_arr_expr = addIndex(expr_2, inp2_iters) - out_arr_expr = addIndex(expr_out, out_iters) +def generateBroadcastLoopBOp( + expr_1, inp1_shape: list, expr_2, inp2_shape: list, expr_out, op: Op.Op +): + output_shape, broadcast_mask_1, broadcast_mask_2 = Util.getBroadcastShapes( + inp1_shape, inp2_shape + ) + out_iters = [Var("i" + str(i)) for i in range(len(output_shape))] + inp1_iters = getMaskedIters(broadcast_mask_1, out_iters, inp1_shape) + inp2_iters = getMaskedIters(broadcast_mask_2, out_iters, inp2_shape) + + inp1_arr_expr = addIndex(expr_1, inp1_iters) + inp2_arr_expr = addIndex(expr_2, inp2_iters) + out_arr_expr = addIndex(expr_out, out_iters) + + assign_expr = Assn(out_arr_expr, IntBop(inp1_arr_expr, op, inp2_arr_expr)) + out_loop = loop(output_shape, out_iters, [assign_expr]) + out_prog = Prog(out_loop) + return out_prog - assign_expr = Assn(out_arr_expr, IntBop(inp1_arr_expr, op, inp2_arr_expr)) - out_loop = loop(output_shape, out_iters, [assign_expr]) - out_prog = Prog(out_loop) - return out_prog # Generates the index into a flattened tensor. # Example: @@ -222,10 +327,12 @@ def generateBroadcastLoopBOp(expr_1, inp1_shape: list, expr_2, inp2_shape : list # for i3=[0:s3] # for i4=[0:s4] # generate (i1*s2*s3*s4) + (i2*s3*s4) + (i3*s4) + (i4); -def getFlatArrIdxExpr(iters:list, shape:list): - assert len(iters) == len(shape), "No. of loop idx vars should be equal to loop shapes" - flat_idx_expr = Int(0,32) - for i in range(len(iters)): - vol = get_volume(shape[i+1:]) - flat_idx_expr = add(flat_idx_expr, mul(iters[i], Int(vol,32))) - return flat_idx_expr \ No newline at end of file +def getFlatArrIdxExpr(iters: list, shape: list): + assert len(iters) == len( + shape + ), "No. of loop idx vars should be equal to loop shapes" + flat_idx_expr = Int(0, 32) + for i in range(len(iters)): + vol = get_volume(shape[i + 1 :]) + flat_idx_expr = add(flat_idx_expr, mul(iters[i], Int(vol, 32))) + return flat_idx_expr diff --git a/Athos/SeeDot/Optimizations/GarbageCollector.py b/Athos/SeeDot/Optimizations/GarbageCollector.py index 27a2dcaf..1cefa810 100644 --- a/Athos/SeeDot/Optimizations/GarbageCollector.py +++ b/Athos/SeeDot/Optimizations/GarbageCollector.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import AST.AST as AST import Util @@ -29,112 +29,120 @@ class SecretFlowAnalysis(ASTVisitor): - def __init__(self): - self.idf_to_secret = {} - self.node_to_secret = {} - - def isSecret(self, idf:str): - return self.idf_to_secret[idf] - - def visitInt(self, node:AST.Int, args): - self.node_to_secret[node] = node.isSecret - - def visitFloat(self, node:AST.Float, args): - self.node_to_secret[node] = node.isSecret - - def visitInput(self, node:AST.Input, args): - self.node_to_secret[node] = node.isSecret - - def visitId(self, node:AST.ID, args): - self.node_to_secret[node] = self.idf_to_secret[node.name] - - def visitLet(self, node:AST.Let, args): - self.visit(node.decl, args) - self.idf_to_secret[node.name.name] = self.node_to_secret[node.decl] - self.visit(node.expr, args) - - def visitDecl(self, node:AST.Decl, args): - self.node_to_secret[node] = node.isSecret - if node.valueList: - for elem in node.valueList: - self.visit(elem, args) - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args): - self.node_to_secret[node] = node.isSecret - for elem in node.argsList: - self.visit(elem, args) - - def visitTranspose(self, node:AST.Transpose, args): - self.visit(node.expr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] - - def visitSlice(self, node:AST.Slice, args): - self.visit(node.expr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] - - def visitReshape(self, node:AST.Reshape, args): - self.visit(node.expr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] - - def visitPool(self, node:AST.Pool, args): - self.visit(node.expr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] - - def visitUOp(self, node:AST.UOp, args): - self.visit(node.expr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] - - def visitBOp(self, node:AST.BOp, args): - self.visit(node.expr1, args) - self.visit(node.expr2, args) - self.node_to_secret[node] = self.node_to_secret[node.expr1] | self.node_to_secret[node.expr1] - - def visitFunc(self, node:AST.Func, args): - self.visit(node.expr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] - - - def visitArgMax(self, node:AST.ArgMax, args): - self.visit(node.expr, args) - self.visit(node.dim, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] | self.node_to_secret[node.dim] - - def visitReduce(self, node:AST.Reduce, args): - self.visit(node.expr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args): - self.visit(node.expr, args) - self.visit(node.multExpr, args) - self.visit(node.addExpr, args) - self.node_to_secret[node] = self.node_to_secret[node.expr] | self.node_to_secret[node.multExpr] | self.node_to_secret[node.addExpr] - - -# A very basic alias analysis pass which creates alias sets for variables created + def __init__(self): + self.idf_to_secret = {} + self.node_to_secret = {} + + def isSecret(self, idf: str): + return self.idf_to_secret[idf] + + def visitInt(self, node: AST.Int, args): + self.node_to_secret[node] = node.isSecret + + def visitFloat(self, node: AST.Float, args): + self.node_to_secret[node] = node.isSecret + + def visitInput(self, node: AST.Input, args): + self.node_to_secret[node] = node.isSecret + + def visitId(self, node: AST.ID, args): + self.node_to_secret[node] = self.idf_to_secret[node.name] + + def visitLet(self, node: AST.Let, args): + self.visit(node.decl, args) + self.idf_to_secret[node.name.name] = self.node_to_secret[node.decl] + self.visit(node.expr, args) + + def visitDecl(self, node: AST.Decl, args): + self.node_to_secret[node] = node.isSecret + if node.valueList: + for elem in node.valueList: + self.visit(elem, args) + + def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, args): + self.node_to_secret[node] = node.isSecret + for elem in node.argsList: + self.visit(elem, args) + + def visitTranspose(self, node: AST.Transpose, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitSlice(self, node: AST.Slice, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitReshape(self, node: AST.Reshape, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitPool(self, node: AST.Pool, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitUOp(self, node: AST.UOp, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitBOp(self, node: AST.BOp, args): + self.visit(node.expr1, args) + self.visit(node.expr2, args) + self.node_to_secret[node] = ( + self.node_to_secret[node.expr1] | self.node_to_secret[node.expr1] + ) + + def visitFunc(self, node: AST.Func, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitArgMax(self, node: AST.ArgMax, args): + self.visit(node.expr, args) + self.visit(node.dim, args) + self.node_to_secret[node] = ( + self.node_to_secret[node.expr] | self.node_to_secret[node.dim] + ) + + def visitReduce(self, node: AST.Reduce, args): + self.visit(node.expr, args) + self.node_to_secret[node] = self.node_to_secret[node.expr] + + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args): + self.visit(node.expr, args) + self.visit(node.multExpr, args) + self.visit(node.addExpr, args) + self.node_to_secret[node] = ( + self.node_to_secret[node.expr] + | self.node_to_secret[node.multExpr] + | self.node_to_secret[node.addExpr] + ) + + +# A very basic alias analysis pass which creates alias sets for variables created # through identity ops # let a = b class AliasAnalysis(ASTVisitor): - def __init__(self): - self.alias_sets = Util.DisjointSet() - super().__init__() + def __init__(self): + self.alias_sets = Util.DisjointSet() + super().__init__() + + def add_alias(self, inp1, inp2): + self.alias_sets.make_set(inp1) + self.alias_sets.make_set(inp2) + self.alias_sets.union(inp1, inp2) - def add_alias(self, inp1, inp2): - self.alias_sets.make_set(inp1) - self.alias_sets.make_set(inp2) - self.alias_sets.union(inp1, inp2) + def get_alias_set(self, inp): + return self.alias_sets.get_key_set(inp) - def get_alias_set(self, inp): - return self.alias_sets.get_key_set(inp) + def visitLet(self, node: AST.Let, args): + self.visit(node.decl) + self.visit(node.expr) - def visitLet(self, node:AST.Let, args): - self.visit(node.decl) - self.visit(node.expr) + # Two IDs with same name can have diff pointers. Hence we store ID names instead of pointers. + if isinstance(node.decl, AST.ID): + self.add_alias(node.name.name, node.decl.name) - # Two IDs with same name can have diff pointers. Hence we store ID names instead of pointers. - if isinstance(node.decl, AST.ID): - self.add_alias(node.name.name, node.decl.name) -''' +""" We visit the program bottom up. Every time we encounter a use of a variable, we insert a free instruction after it, unless the variable has already been freed. We are basically freeing variables after their last use. @@ -157,116 +165,130 @@ def visitLet(self, node:AST.Let, args): free(J100) .. -''' +""" + + class GarbageCollector(ASTVisitor): - def __init__(self, ast): - self.ast = ast - self.secret_analysis = SecretFlowAnalysis() - self.secret_analysis.visit(self.ast) - self.alias_analysis = AliasAnalysis() - self.alias_analysis.visit(self.ast) - self.freed_nodes = set() - self.counter = 0 - super().__init__() - - def run(self, args): - self.visit(self.ast, args) - - def isVarFreed(self, inp): - alias_set = self.alias_analysis.get_alias_set(inp) - if alias_set is None: - return inp in self.freed_nodes - for i in alias_set: - if i in self.freed_nodes: - return True - return False - - def visitLet(self, node:AST.Let, args): - assert(isinstance(args, list)) - assert(isinstance(args[0], MtdAST)) - - self.visit(node.expr, args) - - usedVars = self.visit(node.decl, args) - if usedVars is None: - assert False, " visit of {} not implemented in GarbageCollector pass".format(str(type(node.decl))) - - varsToDeAllocate = [i for i in usedVars if not self.isVarFreed(i)] - self.freed_nodes = self.freed_nodes.union(set(varsToDeAllocate)) - - astSubTree = node.expr - mtdForNewASTNodes = {AST.ASTNode.mtdKeyTFOpName : "No-op: ClearMem", - AST.ASTNode.mtdKeyTFNodeName : ""} - for ii, curVarName in enumerate(varsToDeAllocate): - newSubTree = AST.Let(AST.ID("cv"+str(self.counter+ii)), - AST.Func(AST.Operators.ClearMemSecret if self.secret_analysis.isSecret(curVarName) else AST.Operators.ClearMemPublic, - AST.ID(curVarName)), - AST.ID("")) - self.counter += 1 - args[0].visit(newSubTree, mtdForNewASTNodes) - newSubTree.expr = astSubTree - node.expr = newSubTree - astSubTree = node.expr - - def visitInt(self, node:AST.Int, args): - return set() - - def visitFloat(self, node:AST.Float, args): - return set() - - def visitInput(self, node:AST.Input, args): - return set() - - def visitId(self, node:AST.ID, args): - return set([node.name]) - - def visitDecl(self, node:AST.Decl, args): - return set() - - def visitTranspose(self, node:AST.Transpose, args): - usedVars = self.visit(node.expr, args) - return usedVars - - def visitSlice(self, node:AST.Slice, args): - usedVars = self.visit(node.expr, args) - return usedVars - - def visitReshape(self, node:AST.Reshape, args): - usedVars = self.visit(node.expr, args) - return usedVars - - def visitPool(self, node:AST.Pool, args): - usedVars = self.visit(node.expr, args) - return usedVars - - def visitUOp(self, node:AST.UOp, args): - usedVars = self.visit(node.expr, args) - return usedVars - - def visitBOp(self, node:AST.BOp, args): - usedVars = self.visit(node.expr1, args) | self.visit(node.expr2, args) - return usedVars - - def visitFunc(self, node:AST.Func, args): - usedVars = self.visit(node.expr, args) - return usedVars - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args): - usedVars = set([]) - for elem in node.argsList: - usedVars |= self.visit(elem, args) - return usedVars - - def visitArgMax(self, node:AST.ArgMax, args): - usedVars = self.visit(node.expr, args) | self.visit(node.dim, args) - return usedVars - - def visitReduce(self, node:AST.Reduce, args): - usedVars = self.visit(node.expr, args) - return usedVars - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args): - usedVars = self.visit(node.expr, args) - usedVars |= self.visit(node.multExpr, args) - usedVars |= self.visit(node.addExpr, args) - return usedVars \ No newline at end of file + def __init__(self, ast): + self.ast = ast + self.secret_analysis = SecretFlowAnalysis() + self.secret_analysis.visit(self.ast) + self.alias_analysis = AliasAnalysis() + self.alias_analysis.visit(self.ast) + self.freed_nodes = set() + self.counter = 0 + super().__init__() + + def run(self, args): + self.visit(self.ast, args) + + def isVarFreed(self, inp): + alias_set = self.alias_analysis.get_alias_set(inp) + if alias_set is None: + return inp in self.freed_nodes + for i in alias_set: + if i in self.freed_nodes: + return True + return False + + def visitLet(self, node: AST.Let, args): + assert isinstance(args, list) + assert isinstance(args[0], MtdAST) + + self.visit(node.expr, args) + + usedVars = self.visit(node.decl, args) + if usedVars is None: + assert ( + False + ), " visit of {} not implemented in GarbageCollector pass".format( + str(type(node.decl)) + ) + + varsToDeAllocate = [i for i in usedVars if not self.isVarFreed(i)] + self.freed_nodes = self.freed_nodes.union(set(varsToDeAllocate)) + + astSubTree = node.expr + mtdForNewASTNodes = { + AST.ASTNode.mtdKeyTFOpName: "No-op: ClearMem", + AST.ASTNode.mtdKeyTFNodeName: "", + } + for ii, curVarName in enumerate(varsToDeAllocate): + newSubTree = AST.Let( + AST.ID("cv" + str(self.counter + ii)), + AST.Func( + AST.Operators.ClearMemSecret + if self.secret_analysis.isSecret(curVarName) + else AST.Operators.ClearMemPublic, + AST.ID(curVarName), + ), + AST.ID(""), + ) + self.counter += 1 + args[0].visit(newSubTree, mtdForNewASTNodes) + newSubTree.expr = astSubTree + node.expr = newSubTree + astSubTree = node.expr + + def visitInt(self, node: AST.Int, args): + return set() + + def visitFloat(self, node: AST.Float, args): + return set() + + def visitInput(self, node: AST.Input, args): + return set() + + def visitId(self, node: AST.ID, args): + return set([node.name]) + + def visitDecl(self, node: AST.Decl, args): + return set() + + def visitTranspose(self, node: AST.Transpose, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitSlice(self, node: AST.Slice, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitReshape(self, node: AST.Reshape, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitPool(self, node: AST.Pool, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitUOp(self, node: AST.UOp, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitBOp(self, node: AST.BOp, args): + usedVars = self.visit(node.expr1, args) | self.visit(node.expr2, args) + return usedVars + + def visitFunc(self, node: AST.Func, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, args): + usedVars = set([]) + for elem in node.argsList: + usedVars |= self.visit(elem, args) + return usedVars + + def visitArgMax(self, node: AST.ArgMax, args): + usedVars = self.visit(node.expr, args) | self.visit(node.dim, args) + return usedVars + + def visitReduce(self, node: AST.Reduce, args): + usedVars = self.visit(node.expr, args) + return usedVars + + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args): + usedVars = self.visit(node.expr, args) + usedVars |= self.visit(node.multExpr, args) + usedVars |= self.visit(node.addExpr, args) + return usedVars diff --git a/Athos/SeeDot/Optimizations/ReluMaxpoolOpti.py b/Athos/SeeDot/Optimizations/ReluMaxpoolOpti.py index 46718283..21a7a664 100644 --- a/Athos/SeeDot/Optimizations/ReluMaxpoolOpti.py +++ b/Athos/SeeDot/Optimizations/ReluMaxpoolOpti.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,29 +20,31 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import AST.AST as AST from AST.ASTVisitor import ASTVisitor -class ReluMaxpoolOpti(ASTVisitor): - def visitLet(self, node:AST.Let, args:dict): - if isinstance(node.decl, AST.Func) and node.decl.op==AST.Operators.RELU: - # Relu declaration entered - if isinstance(node.expr, AST.Let): - # There is a next let statement - if isinstance(node.expr.decl, AST.Pool) and (node.expr.decl.poolType==AST.Pool.PoolType.MaxPool): - # This is the case of relu followed by maxpool declaration - # Switch here - print("Found relu followed by maxpool. Performing optimization.") - # Assuming here that only maxpool's output is subsequently used. - # TODO: Do something for above? - reluDecl, maxpoolDecl = node.decl, node.expr.decl - maxpoolDecl.expr = reluDecl.expr - reluDecl.expr = node.name - node.decl = maxpoolDecl - node.expr.decl = reluDecl - self.visit(node.name) - self.visit(node.decl) - self.visit(node.expr) +class ReluMaxpoolOpti(ASTVisitor): + def visitLet(self, node: AST.Let, args: dict): + if isinstance(node.decl, AST.Func) and node.decl.op == AST.Operators.RELU: + # Relu declaration entered + if isinstance(node.expr, AST.Let): + # There is a next let statement + if isinstance(node.expr.decl, AST.Pool) and ( + node.expr.decl.poolType == AST.Pool.PoolType.MaxPool + ): + # This is the case of relu followed by maxpool declaration + # Switch here + print("Found relu followed by maxpool. Performing optimization.") + # Assuming here that only maxpool's output is subsequently used. + # TODO: Do something for above? + reluDecl, maxpoolDecl = node.decl, node.expr.decl + maxpoolDecl.expr = reluDecl.expr + reluDecl.expr = node.name + node.decl = maxpoolDecl + node.expr.decl = reluDecl + self.visit(node.name) + self.visit(node.decl) + self.visit(node.expr) diff --git a/Athos/SeeDot/SeeDot.py b/Athos/SeeDot/SeeDot.py index a45a8321..ba20a5c7 100644 --- a/Athos/SeeDot/SeeDot.py +++ b/Athos/SeeDot/SeeDot.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,75 +20,137 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os, sys import argparse import Util from Compiler import Compiler + class MainDriver: - def parseArgs(self): - def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('true'): - return True - elif v.lower() in ('false'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - parser = argparse.ArgumentParser() + def parseArgs(self): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true"): + return True + elif v.lower() in ("false"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser() + + parser.add_argument( + "-v", + "--version", + choices=Util.Version.All, + default=Util.Version.Fixed, + metavar="", + help="Floating point code or fixed point code", + ) + parser.add_argument( + "-t", + "--target", + choices=Util.Target.All, + default=Util.Target.EzPC, + metavar="", + help="EzPC code or something else", + ) + parser.add_argument( + "--sfType", + choices=Util.SFType.All, + default=Util.SFType.Constant, + metavar="", + help="Use constant/variable SF", + ) + parser.add_argument("--astFile", help="Load AST from this file") + parser.add_argument( + "-p", "--printAST", default=False, type=bool, help="Print the AST or not." + ) + parser.add_argument( + "--consSF", default=15, type=int, help="Use this constant scaling factor." + ) + parser.add_argument( + "--bitlen", + default=64, + type=int, + help="Bitlength to compile to. Defaults to 64.", + ) + parser.add_argument( + "--disableRMO", + default=False, + type=str2bool, + help="Disable Relu-Maxpool optimization.", + ) + parser.add_argument( + "--disableLivenessOpti", + default=False, + type=str2bool, + help="Disable liveness optimization.", + ) + parser.add_argument( + "--disableTruncOpti", + default=False, + type=str2bool, + help="Disable truncation placement optimization.", + ) + parser.add_argument( + "--disableAllOpti", + default=False, + type=str2bool, + help="Disable all optimizations.", + ) + parser.add_argument( + "--outputFileName", + help="Name of the output file with extension (Donot include folder path).", + ) + parser.add_argument( + "--debugVar", type=str, help="Name of the onnx node to be debugged" + ) + + self.args = parser.parse_args() - parser.add_argument("-v", "--version", choices=Util.Version.All, default=Util.Version.Fixed, metavar='', help="Floating point code or fixed point code") - parser.add_argument("-t", "--target", choices=Util.Target.All, default=Util.Target.EzPC, metavar='', help="EzPC code or something else") - parser.add_argument("--sfType", choices=Util.SFType.All, default=Util.SFType.Constant, metavar='', help="Use constant/variable SF" ) - parser.add_argument("--astFile", help="Load AST from this file" ) - parser.add_argument("-p", "--printAST", default=False, type=bool, help="Print the AST or not.") - parser.add_argument("--consSF", default=15, type=int, help="Use this constant scaling factor.") - parser.add_argument("--bitlen", default=64, type=int, help="Bitlength to compile to. Defaults to 64.") - parser.add_argument("--disableRMO", default=False, type=str2bool, help="Disable Relu-Maxpool optimization.") - parser.add_argument("--disableLivenessOpti", default=False, type=str2bool, help="Disable liveness optimization.") - parser.add_argument("--disableTruncOpti", default=False, type=str2bool, help="Disable truncation placement optimization.") - parser.add_argument("--disableAllOpti", default=False, type=str2bool, help="Disable all optimizations.") - parser.add_argument("--outputFileName", help="Name of the output file with extension (Donot include folder path).") - parser.add_argument("--debugVar", type=str, help="Name of the onnx node to be debugged") - - self.args = parser.parse_args() + def runCompilerDriver(self): + print( + "Generating {0} point code for {1} target with sfType={2}, consSF={3} and bitlen={4}.".format( + self.args.version, + self.args.target, + self.args.sfType, + self.args.consSF, + self.args.bitlen, + ) + ) + if self.args.disableAllOpti: + print("Running with all optimizations disabled.") + elif self.args.disableRMO: + print("Running with Relu-Maxpool optimization disabled.") + elif self.args.disableLivenessOpti: + print("Running with liveness optimization disabled.") + elif self.args.disableTruncOpti: + print("Running with truncation placement optimization disabled.") - def runCompilerDriver(self): - print("Generating {0} point code for {1} target with sfType={2}, consSF={3} and bitlen={4}.".format(self.args.version, - self.args.target, - self.args.sfType, - self.args.consSF, - self.args.bitlen)) - if self.args.disableAllOpti: - print("Running with all optimizations disabled.") - elif self.args.disableRMO: - print("Running with Relu-Maxpool optimization disabled.") - elif self.args.disableLivenessOpti: - print("Running with liveness optimization disabled.") - elif self.args.disableTruncOpti: - print("Running with truncation placement optimization disabled.") + obj = Compiler( + self.args.version, + self.args.target, + self.args.sfType, + self.args.astFile, + self.args.printAST, + self.args.consSF, + self.args.bitlen, + self.args.outputFileName, + self.args.disableRMO, + self.args.disableLivenessOpti, + self.args.disableTruncOpti, + self.args.disableAllOpti, + self.args.debugVar, + ) + obj.run() - obj = Compiler(self.args.version, - self.args.target, - self.args.sfType, - self.args.astFile, - self.args.printAST, - self.args.consSF, - self.args.bitlen, - self.args.outputFileName, - self.args.disableRMO, - self.args.disableLivenessOpti, - self.args.disableTruncOpti, - self.args.disableAllOpti, - self.args.debugVar - ) - obj.run() if __name__ == "__main__": - sys.setrecursionlimit(10000) - obj = MainDriver() - obj.parseArgs() - obj.runCompilerDriver() + sys.setrecursionlimit(10000) + obj = MainDriver() + obj.parseArgs() + obj.runCompilerDriver() diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 7c37c3f5..4d549a4c 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import Util import operator @@ -30,10 +30,12 @@ from enum import Enum, auto import copy + class Type: - pass + pass + -''' +""" We want to analyse the taint of every tensor that flows in the graph. The possible taints for tensors are: { @@ -53,484 +55,544 @@ class Type: C&S C&S C&S C&S C&S C&S Secret_constant C&S C&S C&S Secret_constant Secret_constant Public_constant Client Server C&S Secret_constant Public_constant -''' +""" + class Taints(Enum): - CLIENT = auto() - SERVER = auto() - CLIENT_SERVER = auto() - SECRET_C = auto() - PUBLIC_C = auto() + CLIENT = auto() + SERVER = auto() + CLIENT_SERVER = auto() + SECRET_C = auto() + PUBLIC_C = auto() -constantTaintsMapping = { True : Taints.SECRET_C, False : Taints.PUBLIC_C} + +constantTaintsMapping = {True: Taints.SECRET_C, False: Taints.PUBLIC_C} TaintsTable = { - Taints.CLIENT : { - Taints.CLIENT : Taints.CLIENT, - Taints.SERVER : Taints.CLIENT_SERVER, - Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, - Taints.SECRET_C: Taints.CLIENT, - Taints.PUBLIC_C: Taints.CLIENT - }, - Taints.SERVER : { - Taints.CLIENT : Taints.CLIENT_SERVER, - Taints.SERVER : Taints.SERVER, - Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, - Taints.SECRET_C: Taints.SERVER, - Taints.PUBLIC_C: Taints.SERVER - }, - Taints.CLIENT_SERVER : { - Taints.CLIENT : Taints.CLIENT_SERVER, - Taints.SERVER : Taints.CLIENT_SERVER, - Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, - Taints.SECRET_C: Taints.CLIENT_SERVER, - Taints.PUBLIC_C: Taints.CLIENT_SERVER - }, - Taints.SECRET_C : { - Taints.CLIENT : Taints.CLIENT, - Taints.SERVER : Taints.SERVER, - Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, - Taints.SECRET_C: Taints.SECRET_C, - Taints.PUBLIC_C: Taints.SECRET_C - }, - Taints.PUBLIC_C : { - Taints.CLIENT : Taints.CLIENT, - Taints.SERVER : Taints.SERVER, - Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, - Taints.SECRET_C: Taints.SECRET_C, - Taints.PUBLIC_C: Taints.PUBLIC_C - } - } + Taints.CLIENT: { + Taints.CLIENT: Taints.CLIENT, + Taints.SERVER: Taints.CLIENT_SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.CLIENT, + Taints.PUBLIC_C: Taints.CLIENT, + }, + Taints.SERVER: { + Taints.CLIENT: Taints.CLIENT_SERVER, + Taints.SERVER: Taints.SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.SERVER, + Taints.PUBLIC_C: Taints.SERVER, + }, + Taints.CLIENT_SERVER: { + Taints.CLIENT: Taints.CLIENT_SERVER, + Taints.SERVER: Taints.CLIENT_SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.CLIENT_SERVER, + Taints.PUBLIC_C: Taints.CLIENT_SERVER, + }, + Taints.SECRET_C: { + Taints.CLIENT: Taints.CLIENT, + Taints.SERVER: Taints.SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.SECRET_C, + Taints.PUBLIC_C: Taints.SECRET_C, + }, + Taints.PUBLIC_C: { + Taints.CLIENT: Taints.CLIENT, + Taints.SERVER: Taints.SERVER, + Taints.CLIENT_SERVER: Taints.CLIENT_SERVER, + Taints.SECRET_C: Taints.SECRET_C, + Taints.PUBLIC_C: Taints.PUBLIC_C, + }, +} + + def getTaint_taint(t1: Taints, t2: Taints): - return TaintsTable[t1][t2] + return TaintsTable[t1][t2] + def getTaint_type(t1: Type, t2: Type): - return TaintsTable[t1.taint][t2.taint] + return TaintsTable[t1.taint][t2.taint] + class Int(Type): - def __init__(self, bitlen=-1, isSecret=False, taint=Taints.PUBLIC_C): - if bitlen==-1: - self.bitlen = Util.Config.wordLength - else: - self.bitlen = bitlen - self.isSecret = isSecret - self.taint = taint + def __init__(self, bitlen=-1, isSecret=False, taint=Taints.PUBLIC_C): + if bitlen == -1: + self.bitlen = Util.Config.wordLength + else: + self.bitlen = bitlen + self.isSecret = isSecret + self.taint = taint + + def __copy__(self): + return type(self)(self.bitlen, self.isSecret, self.taint) - def __copy__(self): - return type(self)(self.bitlen, self.isSecret, self.taint) class Unit(Type): - pass + pass + class Tensor(Type): - def __init__(self, shape:list, bitlen=-1, isSecret=True, taint=Taints.PUBLIC_C): - self.shape = shape - self.dim = len(shape) - if bitlen==-1: - self.bitlen = Util.Config.wordLength - else: - self.bitlen = bitlen - self.isSecret = isSecret - self.taint = taint - - def __copy__(self): - return type(self)(self.shape, self.bitlen, self.isSecret, self.taint) - - def size(self): - return reduce(operator.mul, self.shape, 1) - - # Tensor without any dimension (float) or a tensor with all dimensions equal to 1 - def isShapeOne(self): - return self.dim == 0 or self.size() == 1 - -def isInt(type:Type): - return isinstance(type, Int) - -def isTensor(type:Type): - return isinstance(type, Tensor) - -def isUnit(type:Type): - return isinstance(type, Unit) - -def isEqual(type1:Type, type2:Type): - if isInt(type1) and isInt(type2): - return True - elif isTensor(type1) and isTensor(type2): - if type1.dim != type2.dim: - return False - return type1.shape == type2.shape - else: - assert False + def __init__(self, shape: list, bitlen=-1, isSecret=True, taint=Taints.PUBLIC_C): + self.shape = shape + self.dim = len(shape) + if bitlen == -1: + self.bitlen = Util.Config.wordLength + else: + self.bitlen = bitlen + self.isSecret = isSecret + self.taint = taint + + def __copy__(self): + return type(self)(self.shape, self.bitlen, self.isSecret, self.taint) + + def size(self): + return reduce(operator.mul, self.shape, 1) + + # Tensor without any dimension (float) or a tensor with all dimensions equal to 1 + def isShapeOne(self): + return self.dim == 0 or self.size() == 1 + + +def isInt(type: Type): + return isinstance(type, Int) + + +def isTensor(type: Type): + return isinstance(type, Tensor) + + +def isUnit(type: Type): + return isinstance(type, Unit) + + +def isEqual(type1: Type, type2: Type): + if isInt(type1) and isInt(type2): + return True + elif isTensor(type1) and isTensor(type2): + if type1.dim != type2.dim: + return False + return type1.shape == type2.shape + else: + assert False + class InferType(ASTVisitor): - def visitInt(self, node:AST.Int, args=None): - bitlen = Util.Config.wordLength - if node.bitLen: - bitlen = node.bitLen - node.type = Int(bitlen, node.isSecret, constantTaintsMapping[node.isSecret]) - return node.type - - def visitFloat(self, node:AST.Float, args=None): - # Float is represented as an int in fixedpt. - node.type = Int(isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret]) - return node.type - - def visitId(self, node:AST.ID, args=None): - if node.name not in node.gamma: - print("Error in type checking: Found id which is not contained in gamma.", file=sys.stderr) - assert(False) - else: - node.type = node.gamma[node.name] - return node.type - - def visitDecl(self, node:AST.Decl, args=None): - #TODO -- fill in bitlen properly - if (node.shape == []): - node.type = Int(isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret]) - else: - node.type = Tensor(shape=node.shape, isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret]) - return node.type - - def visitTranspose(self, node:AST.Transpose, args=None): - node.expr.gamma = dict(node.gamma) - exprType = self.visit(node.expr) - - assert isTensor(exprType) - - perm = node.perm - shape = exprType.shape - if (perm is None): - perm = [i for i in reversed(range(len(shape)))] - new_shape = [] - for i in perm: - new_shape.append(shape[i]) - node.type = Tensor(new_shape, exprType.bitlen, exprType.isSecret, exprType.taint) - return node.type - - def visitSlice(self, node:AST.Slice, args=None): - node.expr.gamma = dict(node.gamma) - exprType = self.visit(node.expr) - assert isTensor(exprType) - - subscriptRanges = node.subscriptRanges - shape = [] - for i in subscriptRanges: - start = i[0] - end = i[1] - size = end - start + 1 - shape.append(size) - - assert(len(shape) == len(exprType.shape)) - for i in range(0,len(shape)): - assert shape[i] <= exprType.shape[i], " for {}".format(node.metadata) - - node.type = Tensor(shape, exprType.bitlen, exprType.isSecret, exprType.taint) - return node.type - - def visitReshape(self, node:AST.Reshape, args=None): - node.expr.gamma = dict(node.gamma) - exprType = self.visit(node.expr) - - assert isTensor(exprType) and exprType.dim >= 1 - - # Reshape is valid if the total number of elements remain same after reshape - assert reduce(operator.mul, exprType.shape, 1) == reduce(operator.mul, node.shape, 1) - node.type = Tensor(node.shape, exprType.bitlen, exprType.isSecret, exprType.taint) - - return node.type - - def visitPool(self, node:AST.Pool, args=None): - node.expr.gamma = dict(node.gamma) - exprType = self.visit(node.expr) - - # Implementation only performs maxpool over a 4D input - assert isTensor(exprType) and exprType.dim == 4 - [N, H, W, CI] = exprType.shape - FH = node.options[AST.PaddingKeysDict.FH] - FW = node.options[AST.PaddingKeysDict.FW] - zPadHLeft = node.options[AST.PaddingKeysDict.zPadHLeft] - zPadHRight = node.options[AST.PaddingKeysDict.zPadHRight] - zPadWLeft = node.options[AST.PaddingKeysDict.zPadWLeft] - zPadWRight = node.options[AST.PaddingKeysDict.zPadWRight] - strideH = node.options[AST.PaddingKeysDict.strideH] - strideW = node.options[AST.PaddingKeysDict.strideW] - - newH = ((H + zPadHLeft + zPadHRight - FH)//strideH) + 1 - newW = ((W + zPadWLeft + zPadWRight - FW)//strideW) + 1 - - node.type = Tensor([N, newH, newW, CI], exprType.bitlen, exprType.isSecret, exprType.taint) - - return node.type - - def visitUOp(self, node:AST.UOp, args=None): - node.expr.gamma = dict(node.gamma) - node.type = self.visit(node.expr) - return node.type - - def visitBOp(self, node:AST.BOp, args=None): - node.expr1.gamma = dict(node.gamma) - eType = self.visit(node.expr1) - - node.expr2.gamma = dict(node.gamma) - fType = self.visit(node.expr2) - - if node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv]: - # Ops supporting broadcasting - return self.typeCheckBroadcastOps(node, eType, fType) - elif node.op == AST.Operators.MUL: - return self.visitBopMul(node, eType, fType) - elif node.op == AST.Operators.CONV: - return self.visitBopConv(node, eType, fType) - elif node.op == AST.Operators.CONVTRANSPOSE: - return self.visitBopConvTranspose(node, eType, fType) - else: - assert False - - def typeCheckBroadcastOps(self, node:AST.BOp, eType:Type, fType:Type): - # Ops which support broadcasting have different type checking - # If adding a new op here which supports broadcasting, then be careful! - # Currently, its assumed the op is commutative. If that is not true, following will be wrong ! - - assert node.op in [AST.Operators.ADD, AST.Operators.SUB, AST.Operators.Equal, AST.Operators.ElemWiseMul, AST.Operators.ElemWiseDiv] - if isInt(eType) and isInt(fType): - node.type = Int(eType.bitlen) - elif isTensor(eType) and isTensor(fType): - output_shape, _, _ = Util.getBroadcastShapes(eType.shape, fType.shape) - node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) - elif isTensor(eType) and isInt(fType): - output_shape, _, _ = Util.getBroadcastShapes(eType.shape, []) - node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) - elif isInt(eType) and isTensor(fType): - output_shape, _, _ = Util.getBroadcastShapes([], fType.shape) - node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) - else: - print(eType, fType) - assert False - - node.type.taint = getTaint_type(eType, fType) - node.type.isSecret = eType.isSecret | fType.isSecret - return node.type - - def visitBopMul(self, node:AST.BOp, eType:Type, fType:Type, args=None): - if isInt(eType) and isInt(fType): - node.type = Int(eType.bitlen, eType.isSecret) - elif isTensor(eType) and isTensor(fType): - if eType.dim == 0: - node.type = copy.copy(fType) - elif fType.dim == 0: - node.type = copy.copy(eType) - else: - assert eType.dim == 2 and fType.dim == 2 - [n1, n2] = eType.shape - [n3, n4] = fType.shape - assert n2 == n3 - node.type = Tensor([n1, n4], eType.bitlen) - else: - print("Error: Unknown condition in type checking.", file=sys.stderr) - assert(False) - - node.type.taint = getTaint_type(eType, fType) - node.type.isSecret = eType.isSecret | fType.isSecret - - return node.type - - def visitBopConv(self, node:AST.BOp, eType:Type, fType:Type, args=None): - assert isTensor(eType) and isTensor(fType) - convDim = 2 - group = 1 - if AST.PaddingKeysDict.ConvDim in node.options: - convDim = node.options[AST.PaddingKeysDict.ConvDim] - - if convDim==2: - assert eType.dim == 4 and fType.dim == 4 - elif convDim==3: - assert eType.dim == 5 and fType.dim == 5 - else: - assert(False) - - N = D = H = W = CI = FD = FH = FW = CI1 = CO = -1 - newD = -1 - if (convDim == 2): - [N, H, W, CI] = eType.shape - [FH, FW, CI1, CO] = fType.shape - elif (convDim == 3): - [N, D, H, W, CI] = eType.shape - [FD, FH, FW, CI1, CO] = fType.shape - assert(FD == node.options[AST.PaddingKeysDict.FD]) - zPadDLeft = node.options[AST.PaddingKeysDict.zPadDLeft] - zPadDRight = node.options[AST.PaddingKeysDict.zPadDRight] - strideD = node.options[AST.PaddingKeysDict.strideD] - - newD = ((D + zPadDLeft + zPadDRight - FD)//strideD) + 1 - else: - assert(False) - - if AST.PaddingKeysDict.group in node.options: - group = node.options[AST.PaddingKeysDict.group] - - assert(FH == node.options[AST.PaddingKeysDict.FH]) - assert(FW == node.options[AST.PaddingKeysDict.FW]) - assert CI1*group == CI, "FCI={} group={} CI={}".format(CI1, group, CI) - zPadHLeft = node.options[AST.PaddingKeysDict.zPadHLeft] - zPadHRight = node.options[AST.PaddingKeysDict.zPadHRight] - zPadWLeft = node.options[AST.PaddingKeysDict.zPadWLeft] - zPadWRight = node.options[AST.PaddingKeysDict.zPadWRight] - strideH = node.options[AST.PaddingKeysDict.strideH] - strideW = node.options[AST.PaddingKeysDict.strideW] - - newH = ((H + zPadHLeft + zPadHRight - FH)//strideH) + 1 - newW = ((W + zPadWLeft + zPadWRight - FW)//strideW) + 1 - - if convDim == 2: - shape = [N, newH, newW, CO] - elif convDim == 3: - shape = [N, newD, newH, newW, CO] - node.type = Tensor(shape, eType.bitlen, eType.isSecret | fType.isSecret, getTaint_type(eType, fType)) - return node.type - - def visitBopConvTranspose(self, node:AST.BOp, eType:Type, fType:Type, args=None): - assert isTensor(eType) and isTensor(fType) - - convDim = 2 - if AST.PaddingKeysDict.ConvDim in node.options: - convDim = node.options[AST.PaddingKeysDict.ConvDim] - - if convDim==2: - [N, HP, WP, CI1] = eType.shape - [FH, FW, CO, CI] = fType.shape - elif convDim==3: - [N, DP, HP, WP, CI1] = eType.shape - [FD, FH, FW, CO, CI] = fType.shape - else: - assert(False) - assert(CI1 == CI) - if convDim==3: - outputImgD = node.options[AST.PaddingKeysDict.outputImgD] - outputImgH = node.options[AST.PaddingKeysDict.outputImgH] - outputImgW = node.options[AST.PaddingKeysDict.outputImgW] - - if convDim==2: - shape = [N, outputImgH, outputImgW, CO] - else: - shape = [N, outputImgD, outputImgH, outputImgW, CO] - - # Logic explanation: - # ConvTranpose can be thought of as the inverse of some convolution for which it is doing the upsampling. - # For calculation of padding in the convTranspose operation, the output image size is required. - # This is why TF also mandates the operator to be specified with output size. - # This conv transpose operation can be thought of as conv between output - # of size shape = [N, outputImgH, outputImgW, CI], and filter of size [FH, FW, CI, CO]. - # Hence, the input for this convTranspose would be [N, HP, WP, CO] - - node.type = Tensor(shape, eType.bitlen, eType.isSecret | fType.isSecret, getTaint_type(eType, fType)) - return node.type - - def visitFunc(self, node:AST.Func, args=None): - node.expr.gamma = dict(node.gamma) - eType = self.visit(node.expr) - - if node.op == AST.Operators.RELU: - assert isTensor(eType) and eType.dim >= 1 - node.type = copy.copy(eType) - elif node.op == AST.Operators.TANH: - assert isTensor(eType) - node.type = copy.copy(eType) - elif node.op == AST.Operators.SIGMOID: - assert isTensor(eType) - node.type = copy.copy(eType) - elif node.op == AST.Operators.SQRT: - assert isTensor(eType) - node.type = copy.copy(eType) - elif node.op == AST.Operators.RSQRT: - assert isTensor(eType) - node.type = copy.copy(eType) - elif node.op == AST.Operators.Floor: - node.type = copy.copy(eType) - elif node.op == AST.Operators.Shape: - assert isTensor(eType) - node.type = Tensor([len(eType.shape)], eType.bitlen, eType.isSecret, eType.taint) - elif node.op == AST.Operators.ClearMemSecret: - node.type = Unit() - elif node.op == AST.Operators.ClearMemPublic: - node.type = Unit() - else: - print("Type inference not implemented for", node.op) - assert False - - return node.type - - def visitLet(self, node:AST.Let, args=None): - node.decl.gamma = dict(node.gamma) - eType = self.visit(node.decl) - - node.name.gamma = { node.name.name : eType} - self.visit(node.name) - - node.expr.gamma = dict(node.gamma) - node.expr.gamma[node.name.name] = eType - fType = self.visit(node.expr) - - node.type = copy.copy(fType) - return node.type - - def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args=None): - # Assert that outputShape and inputDims are lists of int astNode. - assert(len(node.argsList) > 0) - isSecret = False - taint = Taints.PUBLIC_C - for curArg in node.argsList: - curArg.gamma = dict(node.gamma) - eType = self.visit(curArg) #This should set the type of each of the input nodes - isSecret = isSecret | eType.isSecret - taint = getTaint_taint(taint, eType.taint) - outputShape = node.outputShape - node.type = Tensor(outputShape, isSecret=isSecret, taint=taint) - return node.type - - def visitArgMax(self, node:AST.ArgMax, args=None): - node.expr.gamma = dict(node.gamma) - eType = self.visit(node.expr) - - node.dim.gamma = dict(node.gamma) - dimType = self.visit(node.dim) - assert(isInt(dimType) or (isTensor(dimType) and (len(dimType.shape)==0))) - - node.type = Tensor(node.outputShape, eType.bitlen, eType.isSecret, eType.taint) - return node.type - - def visitReduce(self, node:AST.Reduce, args=None): - cur_gamma = dict(node.gamma) - node.expr.gamma = cur_gamma - eType = self.visit(node.expr) - - node.type = Tensor(node.outShape, eType.bitlen, eType.isSecret, eType.taint) - return node.type - - def visitInput(self, node:AST.Input, args=None): - node.type = Tensor(node.shape, isSecret=node.isSecret, taint=Taints[node.inputByParty.name]) - return node.type - - def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None): - cur_gamma = dict(node.gamma) - node.expr.gamma = cur_gamma - node.multExpr.gamma = cur_gamma - node.addExpr.gamma = cur_gamma - - exprType = self.visit(node.expr) - multExprType = self.visit(node.multExpr) - addExprType = self.visit(node.addExpr) - - assert(len(multExprType.shape)==1) - assert(len(addExprType.shape)==1) - - [C1] = multExprType.shape - [C2] = addExprType.shape - - assert(exprType.shape[-1]==C1 and C1==C2) - - taint = getTaint_taint(exprType.taint, multExprType.taint) - taint = getTaint_taint(taint, addExprType.taint) - - node.type = copy.copy(exprType) - node.type.taint = taint - return node.type \ No newline at end of file + def visitInt(self, node: AST.Int, args=None): + bitlen = Util.Config.wordLength + if node.bitLen: + bitlen = node.bitLen + node.type = Int(bitlen, node.isSecret, constantTaintsMapping[node.isSecret]) + return node.type + + def visitFloat(self, node: AST.Float, args=None): + # Float is represented as an int in fixedpt. + node.type = Int( + isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret] + ) + return node.type + + def visitId(self, node: AST.ID, args=None): + if node.name not in node.gamma: + print( + "Error in type checking: Found id which is not contained in gamma.", + file=sys.stderr, + ) + assert False + else: + node.type = node.gamma[node.name] + return node.type + + def visitDecl(self, node: AST.Decl, args=None): + # TODO -- fill in bitlen properly + if node.shape == []: + node.type = Int( + isSecret=node.isSecret, taint=constantTaintsMapping[node.isSecret] + ) + else: + node.type = Tensor( + shape=node.shape, + isSecret=node.isSecret, + taint=constantTaintsMapping[node.isSecret], + ) + return node.type + + def visitTranspose(self, node: AST.Transpose, args=None): + node.expr.gamma = dict(node.gamma) + exprType = self.visit(node.expr) + + assert isTensor(exprType) + + perm = node.perm + shape = exprType.shape + if perm is None: + perm = [i for i in reversed(range(len(shape)))] + new_shape = [] + for i in perm: + new_shape.append(shape[i]) + node.type = Tensor( + new_shape, exprType.bitlen, exprType.isSecret, exprType.taint + ) + return node.type + + def visitSlice(self, node: AST.Slice, args=None): + node.expr.gamma = dict(node.gamma) + exprType = self.visit(node.expr) + assert isTensor(exprType) + + subscriptRanges = node.subscriptRanges + shape = [] + for i in subscriptRanges: + start = i[0] + end = i[1] + size = end - start + 1 + shape.append(size) + + assert len(shape) == len(exprType.shape) + for i in range(0, len(shape)): + assert shape[i] <= exprType.shape[i], " for {}".format(node.metadata) + + node.type = Tensor(shape, exprType.bitlen, exprType.isSecret, exprType.taint) + return node.type + + def visitReshape(self, node: AST.Reshape, args=None): + node.expr.gamma = dict(node.gamma) + exprType = self.visit(node.expr) + + assert isTensor(exprType) and exprType.dim >= 1 + + # Reshape is valid if the total number of elements remain same after reshape + assert reduce(operator.mul, exprType.shape, 1) == reduce( + operator.mul, node.shape, 1 + ) + node.type = Tensor( + node.shape, exprType.bitlen, exprType.isSecret, exprType.taint + ) + + return node.type + + def visitPool(self, node: AST.Pool, args=None): + node.expr.gamma = dict(node.gamma) + exprType = self.visit(node.expr) + + # Implementation only performs maxpool over a 4D input + assert isTensor(exprType) and exprType.dim == 4 + [N, H, W, CI] = exprType.shape + FH = node.options[AST.PaddingKeysDict.FH] + FW = node.options[AST.PaddingKeysDict.FW] + zPadHLeft = node.options[AST.PaddingKeysDict.zPadHLeft] + zPadHRight = node.options[AST.PaddingKeysDict.zPadHRight] + zPadWLeft = node.options[AST.PaddingKeysDict.zPadWLeft] + zPadWRight = node.options[AST.PaddingKeysDict.zPadWRight] + strideH = node.options[AST.PaddingKeysDict.strideH] + strideW = node.options[AST.PaddingKeysDict.strideW] + + newH = ((H + zPadHLeft + zPadHRight - FH) // strideH) + 1 + newW = ((W + zPadWLeft + zPadWRight - FW) // strideW) + 1 + + node.type = Tensor( + [N, newH, newW, CI], exprType.bitlen, exprType.isSecret, exprType.taint + ) + + return node.type + + def visitUOp(self, node: AST.UOp, args=None): + node.expr.gamma = dict(node.gamma) + node.type = self.visit(node.expr) + return node.type + + def visitBOp(self, node: AST.BOp, args=None): + node.expr1.gamma = dict(node.gamma) + eType = self.visit(node.expr1) + + node.expr2.gamma = dict(node.gamma) + fType = self.visit(node.expr2) + + if node.op in [ + AST.Operators.ADD, + AST.Operators.SUB, + AST.Operators.Equal, + AST.Operators.ElemWiseMul, + AST.Operators.ElemWiseDiv, + ]: + # Ops supporting broadcasting + return self.typeCheckBroadcastOps(node, eType, fType) + elif node.op == AST.Operators.MUL: + return self.visitBopMul(node, eType, fType) + elif node.op == AST.Operators.CONV: + return self.visitBopConv(node, eType, fType) + elif node.op == AST.Operators.CONVTRANSPOSE: + return self.visitBopConvTranspose(node, eType, fType) + else: + assert False + + def typeCheckBroadcastOps(self, node: AST.BOp, eType: Type, fType: Type): + # Ops which support broadcasting have different type checking + # If adding a new op here which supports broadcasting, then be careful! + # Currently, its assumed the op is commutative. If that is not true, following will be wrong ! + + assert node.op in [ + AST.Operators.ADD, + AST.Operators.SUB, + AST.Operators.Equal, + AST.Operators.ElemWiseMul, + AST.Operators.ElemWiseDiv, + ] + if isInt(eType) and isInt(fType): + node.type = Int(eType.bitlen) + elif isTensor(eType) and isTensor(fType): + output_shape, _, _ = Util.getBroadcastShapes(eType.shape, fType.shape) + node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) + elif isTensor(eType) and isInt(fType): + output_shape, _, _ = Util.getBroadcastShapes(eType.shape, []) + node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) + elif isInt(eType) and isTensor(fType): + output_shape, _, _ = Util.getBroadcastShapes([], fType.shape) + node.type = Tensor(shape=output_shape, bitlen=eType.bitlen) + else: + print(eType, fType) + assert False + + node.type.taint = getTaint_type(eType, fType) + node.type.isSecret = eType.isSecret | fType.isSecret + return node.type + + def visitBopMul(self, node: AST.BOp, eType: Type, fType: Type, args=None): + if isInt(eType) and isInt(fType): + node.type = Int(eType.bitlen, eType.isSecret) + elif isTensor(eType) and isTensor(fType): + if eType.dim == 0: + node.type = copy.copy(fType) + elif fType.dim == 0: + node.type = copy.copy(eType) + else: + assert eType.dim == 2 and fType.dim == 2 + [n1, n2] = eType.shape + [n3, n4] = fType.shape + assert n2 == n3 + node.type = Tensor([n1, n4], eType.bitlen) + else: + print("Error: Unknown condition in type checking.", file=sys.stderr) + assert False + + node.type.taint = getTaint_type(eType, fType) + node.type.isSecret = eType.isSecret | fType.isSecret + + return node.type + + def visitBopConv(self, node: AST.BOp, eType: Type, fType: Type, args=None): + assert isTensor(eType) and isTensor(fType) + convDim = 2 + group = 1 + if AST.PaddingKeysDict.ConvDim in node.options: + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim == 2: + assert eType.dim == 4 and fType.dim == 4 + elif convDim == 3: + assert eType.dim == 5 and fType.dim == 5 + else: + assert False + + N = D = H = W = CI = FD = FH = FW = CI1 = CO = -1 + newD = -1 + if convDim == 2: + [N, H, W, CI] = eType.shape + [FH, FW, CI1, CO] = fType.shape + elif convDim == 3: + [N, D, H, W, CI] = eType.shape + [FD, FH, FW, CI1, CO] = fType.shape + assert FD == node.options[AST.PaddingKeysDict.FD] + zPadDLeft = node.options[AST.PaddingKeysDict.zPadDLeft] + zPadDRight = node.options[AST.PaddingKeysDict.zPadDRight] + strideD = node.options[AST.PaddingKeysDict.strideD] + + newD = ((D + zPadDLeft + zPadDRight - FD) // strideD) + 1 + else: + assert False + + if AST.PaddingKeysDict.group in node.options: + group = node.options[AST.PaddingKeysDict.group] + + assert FH == node.options[AST.PaddingKeysDict.FH] + assert FW == node.options[AST.PaddingKeysDict.FW] + assert CI1 * group == CI, "FCI={} group={} CI={}".format(CI1, group, CI) + zPadHLeft = node.options[AST.PaddingKeysDict.zPadHLeft] + zPadHRight = node.options[AST.PaddingKeysDict.zPadHRight] + zPadWLeft = node.options[AST.PaddingKeysDict.zPadWLeft] + zPadWRight = node.options[AST.PaddingKeysDict.zPadWRight] + strideH = node.options[AST.PaddingKeysDict.strideH] + strideW = node.options[AST.PaddingKeysDict.strideW] + + newH = ((H + zPadHLeft + zPadHRight - FH) // strideH) + 1 + newW = ((W + zPadWLeft + zPadWRight - FW) // strideW) + 1 + + if convDim == 2: + shape = [N, newH, newW, CO] + elif convDim == 3: + shape = [N, newD, newH, newW, CO] + node.type = Tensor( + shape, + eType.bitlen, + eType.isSecret | fType.isSecret, + getTaint_type(eType, fType), + ) + return node.type + + def visitBopConvTranspose(self, node: AST.BOp, eType: Type, fType: Type, args=None): + assert isTensor(eType) and isTensor(fType) + + convDim = 2 + if AST.PaddingKeysDict.ConvDim in node.options: + convDim = node.options[AST.PaddingKeysDict.ConvDim] + + if convDim == 2: + [N, HP, WP, CI1] = eType.shape + [FH, FW, CO, CI] = fType.shape + elif convDim == 3: + [N, DP, HP, WP, CI1] = eType.shape + [FD, FH, FW, CO, CI] = fType.shape + else: + assert False + assert CI1 == CI + if convDim == 3: + outputImgD = node.options[AST.PaddingKeysDict.outputImgD] + outputImgH = node.options[AST.PaddingKeysDict.outputImgH] + outputImgW = node.options[AST.PaddingKeysDict.outputImgW] + + if convDim == 2: + shape = [N, outputImgH, outputImgW, CO] + else: + shape = [N, outputImgD, outputImgH, outputImgW, CO] + + # Logic explanation: + # ConvTranpose can be thought of as the inverse of some convolution for which it is doing the upsampling. + # For calculation of padding in the convTranspose operation, the output image size is required. + # This is why TF also mandates the operator to be specified with output size. + # This conv transpose operation can be thought of as conv between output + # of size shape = [N, outputImgH, outputImgW, CI], and filter of size [FH, FW, CI, CO]. + # Hence, the input for this convTranspose would be [N, HP, WP, CO] + + node.type = Tensor( + shape, + eType.bitlen, + eType.isSecret | fType.isSecret, + getTaint_type(eType, fType), + ) + return node.type + + def visitFunc(self, node: AST.Func, args=None): + node.expr.gamma = dict(node.gamma) + eType = self.visit(node.expr) + + if node.op == AST.Operators.RELU: + assert isTensor(eType) and eType.dim >= 1 + node.type = copy.copy(eType) + elif node.op == AST.Operators.TANH: + assert isTensor(eType) + node.type = copy.copy(eType) + elif node.op == AST.Operators.SIGMOID: + assert isTensor(eType) + node.type = copy.copy(eType) + elif node.op == AST.Operators.SQRT: + assert isTensor(eType) + node.type = copy.copy(eType) + elif node.op == AST.Operators.RSQRT: + assert isTensor(eType) + node.type = copy.copy(eType) + elif node.op == AST.Operators.Floor: + node.type = copy.copy(eType) + elif node.op == AST.Operators.Shape: + assert isTensor(eType) + node.type = Tensor( + [len(eType.shape)], eType.bitlen, eType.isSecret, eType.taint + ) + elif node.op == AST.Operators.ClearMemSecret: + node.type = Unit() + elif node.op == AST.Operators.ClearMemPublic: + node.type = Unit() + else: + print("Type inference not implemented for", node.op) + assert False + + return node.type + + def visitLet(self, node: AST.Let, args=None): + node.decl.gamma = dict(node.gamma) + eType = self.visit(node.decl) + + node.name.gamma = {node.name.name: eType} + self.visit(node.name) + + node.expr.gamma = dict(node.gamma) + node.expr.gamma[node.name.name] = eType + fType = self.visit(node.expr) + + node.type = copy.copy(fType) + return node.type + + def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, args=None): + # Assert that outputShape and inputDims are lists of int astNode. + assert len(node.argsList) > 0 + isSecret = False + taint = Taints.PUBLIC_C + for curArg in node.argsList: + curArg.gamma = dict(node.gamma) + eType = self.visit( + curArg + ) # This should set the type of each of the input nodes + isSecret = isSecret | eType.isSecret + taint = getTaint_taint(taint, eType.taint) + outputShape = node.outputShape + node.type = Tensor(outputShape, isSecret=isSecret, taint=taint) + return node.type + + def visitArgMax(self, node: AST.ArgMax, args=None): + node.expr.gamma = dict(node.gamma) + eType = self.visit(node.expr) + + node.dim.gamma = dict(node.gamma) + dimType = self.visit(node.dim) + assert isInt(dimType) or (isTensor(dimType) and (len(dimType.shape) == 0)) + + node.type = Tensor(node.outputShape, eType.bitlen, eType.isSecret, eType.taint) + return node.type + + def visitReduce(self, node: AST.Reduce, args=None): + cur_gamma = dict(node.gamma) + node.expr.gamma = cur_gamma + eType = self.visit(node.expr) + + node.type = Tensor(node.outShape, eType.bitlen, eType.isSecret, eType.taint) + return node.type + + def visitInput(self, node: AST.Input, args=None): + node.type = Tensor( + node.shape, isSecret=node.isSecret, taint=Taints[node.inputByParty.name] + ) + return node.type + + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None): + cur_gamma = dict(node.gamma) + node.expr.gamma = cur_gamma + node.multExpr.gamma = cur_gamma + node.addExpr.gamma = cur_gamma + + exprType = self.visit(node.expr) + multExprType = self.visit(node.multExpr) + addExprType = self.visit(node.addExpr) + + assert len(multExprType.shape) == 1 + assert len(addExprType.shape) == 1 + + [C1] = multExprType.shape + [C2] = addExprType.shape + + assert exprType.shape[-1] == C1 and C1 == C2 + + taint = getTaint_taint(exprType.taint, multExprType.taint) + taint = getTaint_taint(taint, addExprType.taint) + + node.type = copy.copy(exprType) + node.type.taint = taint + return node.type diff --git a/Athos/SeeDot/Util.py b/Athos/SeeDot/Util.py index 0b729185..6381fe92 100644 --- a/Athos/SeeDot/Util.py +++ b/Athos/SeeDot/Util.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,67 +20,77 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os import _pickle as pickle # Target word length. + class Version: - Fixed = "fixed" - Float = "float" - All = [Fixed, Float] + Fixed = "fixed" + Float = "float" + All = [Fixed, Float] + class Target: - EzPC = "ezpc" - All = [EzPC] + EzPC = "ezpc" + All = [EzPC] + class SFType: - Constant = "constant" - Variable = "variable" - All = [Constant, Variable] + Constant = "constant" + Variable = "variable" + All = [Constant, Variable] + class Config: - version = None - target = None - sfType = None - astFile = None - consSF = None - outputFileName = None - printASTBool = None - wordLength = None - actualWordLength = None - disableRMO = None - disableLivenessOpti = None - disableAllOpti = None - debugOnnx = None + version = None + target = None + sfType = None + astFile = None + consSF = None + outputFileName = None + printASTBool = None + wordLength = None + actualWordLength = None + disableRMO = None + disableLivenessOpti = None + disableAllOpti = None + debugOnnx = None + ###### Helper functions ###### def loadASTFromFile(): - return Config.astFile is not None + return Config.astFile is not None + def forEzPC(): - return Config.target == Target.EzPC + return Config.target == Target.EzPC + + +def copy_dict(dict_src: dict, diff={}): + dict_res = dict(dict_src) + dict_res.update(diff) + return dict_res -def copy_dict(dict_src:dict, diff={}): - dict_res = dict(dict_src) - dict_res.update(diff) - return dict_res # z = [y1,y2,..] = [[x1,..], [x2,..], ..] --> [x1,.., x2,.., ..] -def flatten(z:list): - return [x for y in z for x in y] +def flatten(z: list): + return [x for y in z for x in y] + def write_debug_info(name_mapping): - if not os.path.exists('debug'): - os.makedirs('debug') + if not os.path.exists("debug"): + os.makedirs("debug") - with open('debug/seedot_ezpc_name_map.pkl', 'wb') as f: - pickle.dump(name_mapping, f) + with open("debug/seedot_ezpc_name_map.pkl", "wb") as f: + pickle.dump(name_mapping, f) + + with open("debug/seedot_ezpc_name_map.txt", "w") as f: + for val in name_mapping: + f.write(val + " " + name_mapping[val] + "\n") - with open('debug/seedot_ezpc_name_map.txt', 'w') as f: - for val in name_mapping: - f.write(val + ' ' + name_mapping[val] + '\n') # Broadcasting Rules: # A (4d array): 8 x 1 x 6 x 1 @@ -92,132 +102,138 @@ def write_debug_info(name_mapping): # Result shape: [8, 7, 6, 5] # # If input is a scalar, pass shape as [] -def getBroadcastShapes(Shape1 : list, Shape2 : list): - #Broadcast rules apply in reverse direction - shape1 = Shape1[::-1] - shape2 = Shape2[::-1] - len1 = len(shape1) - len2 = len(shape2) - outputshape = [] - swapped = False - if len1 != len2: - if len1 > len2: - len1, len2 = len2, len1 - shape1, shape2 = shape2, shape1 - swapped = True - assert len1 < len2 - - broadcastMask1 = [False] * len1 - broadcastMask2 = [False] * len2 - - for i in range(len2): - length = 0 - if i >= len1: - #broadcastMask1[i] = True - outputshape.append(shape2[i]) - continue - if shape1[i] != shape2[i]: - if shape1[i] == 1: - outputshape.append(shape2[i]) - broadcastMask1[i] = True - elif shape2[i] == 1: - outputshape.append(shape1[i]) - broadcastMask2[i] = True - else: - print("Dimension no. {} has a mismatch of length.".format(len2 - i)) - assert False, "Cannot broadcast. Program is malformed. Atleast one length should have been 1. i1: {} i2: {}".format(shape1[i], shape2[i]) - else: - outputshape.append(shape1[i]) - - if swapped: - broadcastMask1, broadcastMask2 = broadcastMask2, broadcastMask1 - - outputshape.reverse() - broadcastMask1.reverse() - broadcastMask2.reverse() - return outputshape, broadcastMask1, broadcastMask2 +def getBroadcastShapes(Shape1: list, Shape2: list): + # Broadcast rules apply in reverse direction + shape1 = Shape1[::-1] + shape2 = Shape2[::-1] + len1 = len(shape1) + len2 = len(shape2) + outputshape = [] + swapped = False + if len1 != len2: + if len1 > len2: + len1, len2 = len2, len1 + shape1, shape2 = shape2, shape1 + swapped = True + assert len1 < len2 + + broadcastMask1 = [False] * len1 + broadcastMask2 = [False] * len2 + + for i in range(len2): + length = 0 + if i >= len1: + # broadcastMask1[i] = True + outputshape.append(shape2[i]) + continue + if shape1[i] != shape2[i]: + if shape1[i] == 1: + outputshape.append(shape2[i]) + broadcastMask1[i] = True + elif shape2[i] == 1: + outputshape.append(shape1[i]) + broadcastMask2[i] = True + else: + print("Dimension no. {} has a mismatch of length.".format(len2 - i)) + assert ( + False + ), "Cannot broadcast. Program is malformed. Atleast one length should have been 1. i1: {} i2: {}".format( + shape1[i], shape2[i] + ) + else: + outputshape.append(shape1[i]) + + if swapped: + broadcastMask1, broadcastMask2 = broadcastMask2, broadcastMask1 + + outputshape.reverse() + broadcastMask1.reverse() + broadcastMask2.reverse() + return outputshape, broadcastMask1, broadcastMask2 + def get_volume(shape: list): - vol = 1 - for i in shape: - vol = vol * i - return vol + vol = 1 + for i in shape: + vol = vol * i + return vol + class DisjointSet: - class Node: - def __init__(self): - self.parent = self - self.children = [] - - def get_root(self): - if (self.parent != self): - old_parent = self.parent - self.parent = self.parent.get_root() - if self.parent != old_parent: - self.parent.children.append(self) - old_parent.children.remove(self) - return self.parent - else: - return self - - def get_all_children(self): - all_children = [] - all_children.extend(self.children) - tmp = [] - for i in all_children: - tmp.extend(i.get_all_children()) - all_children.extend(tmp) - return all_children - - def __init__(self): - self.key_to_node = {} - self.node_to_key = {} - - def inSet(self, inp): - return inp in self.key_to_node - - def make_set(self, inp): - if self.inSet(inp): - return - n = self.Node() - self.key_to_node[inp] = n - self.node_to_key[n] = inp - - def union(self, inp1, inp2): - n1 = self.key_to_node[inp1] - n2 = self.key_to_node[inp2] - r1 = n1.get_root() - r2 = n2.get_root() - if (r1 != r2): - r1.parent = r2 - r2.children.append(r1) - - def find(self, inp): - if not self.inSet(inp): - return None - return self.key_to_node[inp].get_root() - - def find_key(self, inp): - node = self.find(inp) - if node is None: - return None - return self.node_to_key[node] - - def get_set(self, inp): - if not self.inSet(inp): - return None - n = self.key_to_node[inp].get_root() - return [n] + n.get_all_children() - - def get_key_set(self, inp): - nodes = self.get_set(inp) - if nodes is None: - return None - return [self.node_to_key[i] for i in nodes] - - def print(self): - print(self.key_to_node) - print(self.node_to_key) - - def print_set(self, inp): - print(self.get_key_set(inp)) \ No newline at end of file + class Node: + def __init__(self): + self.parent = self + self.children = [] + + def get_root(self): + if self.parent != self: + old_parent = self.parent + self.parent = self.parent.get_root() + if self.parent != old_parent: + self.parent.children.append(self) + old_parent.children.remove(self) + return self.parent + else: + return self + + def get_all_children(self): + all_children = [] + all_children.extend(self.children) + tmp = [] + for i in all_children: + tmp.extend(i.get_all_children()) + all_children.extend(tmp) + return all_children + + def __init__(self): + self.key_to_node = {} + self.node_to_key = {} + + def inSet(self, inp): + return inp in self.key_to_node + + def make_set(self, inp): + if self.inSet(inp): + return + n = self.Node() + self.key_to_node[inp] = n + self.node_to_key[n] = inp + + def union(self, inp1, inp2): + n1 = self.key_to_node[inp1] + n2 = self.key_to_node[inp2] + r1 = n1.get_root() + r2 = n2.get_root() + if r1 != r2: + r1.parent = r2 + r2.children.append(r1) + + def find(self, inp): + if not self.inSet(inp): + return None + return self.key_to_node[inp].get_root() + + def find_key(self, inp): + node = self.find(inp) + if node is None: + return None + return self.node_to_key[node] + + def get_set(self, inp): + if not self.inSet(inp): + return None + n = self.key_to_node[inp].get_root() + return [n] + n.get_all_children() + + def get_key_set(self, inp): + nodes = self.get_set(inp) + if nodes is None: + return None + return [self.node_to_key[i] for i in nodes] + + def print(self): + print(self.key_to_node) + print(self.node_to_key) + + def print_set(self, inp): + print(self.get_key_set(inp)) diff --git a/Athos/SeeDot/Writer.py b/Athos/SeeDot/Writer.py index 109b9a42..035e051b 100644 --- a/Athos/SeeDot/Writer.py +++ b/Athos/SeeDot/Writer.py @@ -1,4 +1,4 @@ -''' +""" Authors: Sridhar Gopinath, Nishant Kumar. @@ -20,23 +20,24 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" + class Writer: - def __init__(self, fileName): - self.file = open(fileName, 'w') - self.indentLevel = 0 + def __init__(self, fileName): + self.file = open(fileName, "w") + self.indentLevel = 0 - def printf(self, str, *args, indent=False): - if indent: - self.file.write('\t' * self.indentLevel) - self.file.write(str % args) + def printf(self, str, *args, indent=False): + if indent: + self.file.write("\t" * self.indentLevel) + self.file.write(str % args) - def increaseIndent(self): - self.indentLevel += 1 + def increaseIndent(self): + self.indentLevel += 1 - def decreaseIndent(self): - self.indentLevel -= 1 + def decreaseIndent(self): + self.indentLevel -= 1 - def close(self): - self.file.close() + def close(self): + self.file.close() diff --git a/Athos/TFCompiler/DumpTFMtData.py b/Athos/TFCompiler/DumpTFMtData.py index 40b151e0..c6afb432 100644 --- a/Athos/TFCompiler/DumpTFMtData.py +++ b/Athos/TFCompiler/DumpTFMtData.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,181 +20,228 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import numpy import tensorflow as tf from tensorflow.tools.graph_transforms import TransformGraph + def strip_variable_init_constants(graph_def, input_tensor_names, output_tensor_names): - transforms = [ - 'remove_nodes(op=Identity)', - 'strip_unused_nodes', - ] - # Sanity check if output/input nodes were constant and replaced with variables. - all_node_names = set([i.name for i in graph_def.node]) - def get_true_names(tensor_names, all_nodes): - real_names = [] - for i in tensor_names: - if i not in all_nodes: - var_name = i + "_mpc_const_var" - if var_name in all_nodes: - real_names.append(var_name) - else: - real_names.append(i) - return real_names - real_input_names = get_true_names(input_tensor_names, all_node_names) - real_output_names = get_true_names(output_tensor_names, all_node_names) - optimized_graph_def = TransformGraph(graph_def, real_input_names, real_output_names, transforms) - return optimized_graph_def + transforms = [ + "remove_nodes(op=Identity)", + "strip_unused_nodes", + ] + # Sanity check if output/input nodes were constant and replaced with variables. + all_node_names = set([i.name for i in graph_def.node]) + + def get_true_names(tensor_names, all_nodes): + real_names = [] + for i in tensor_names: + if i not in all_nodes: + var_name = i + "_mpc_const_var" + if var_name in all_nodes: + real_names.append(var_name) + else: + real_names.append(i) + return real_names + + real_input_names = get_true_names(input_tensor_names, all_node_names) + real_output_names = get_true_names(output_tensor_names, all_node_names) + optimized_graph_def = TransformGraph( + graph_def, real_input_names, real_output_names, transforms + ) + return optimized_graph_def + def save_graphdef(graph_def): - with open('./graphDef.mtdata', 'w') as f: - f.write(str(graph_def)) + with open("./graphDef.mtdata", "w") as f: + f.write(str(graph_def)) + def save_sizeinfo(optimized_graph_def, sess, feed_dict): - # Save size information for tensors on which output depends - tensors_to_evaluate = [] - tensors_to_evaluate_names = [] - graph = sess.graph - for node in optimized_graph_def.node: - output_number = 0 - for cur_output in graph.get_operation_by_name(node.name).outputs: - tensors_to_evaluate.append(cur_output) - if output_number == 0: - tensor_name = node.name - else: - tensor_name = cur_output.name - tensors_to_evaluate_names.append(tensor_name) - output_number += 1 - tensors_evaluated = sess.run(tensors_to_evaluate, feed_dict) - tensors_shape = list(map(lambda x : x.shape, tensors_evaluated)) - - # Write size info in a file - with open('./sizeInfo.mtdata','w') as f: - for ii, curr in enumerate(tensors_to_evaluate_names): - curShape = tensors_shape[ii] - f.write(tensors_to_evaluate_names[ii] + ' ') - for dim in curShape: - f.write(str(dim)+' ') - f.write('\n') - + # Save size information for tensors on which output depends + tensors_to_evaluate = [] + tensors_to_evaluate_names = [] + graph = sess.graph + for node in optimized_graph_def.node: + output_number = 0 + for cur_output in graph.get_operation_by_name(node.name).outputs: + tensors_to_evaluate.append(cur_output) + if output_number == 0: + tensor_name = node.name + else: + tensor_name = cur_output.name + tensors_to_evaluate_names.append(tensor_name) + output_number += 1 + tensors_evaluated = sess.run(tensors_to_evaluate, feed_dict) + tensors_shape = list(map(lambda x: x.shape, tensors_evaluated)) + + # Write size info in a file + with open("./sizeInfo.mtdata", "w") as f: + for ii, curr in enumerate(tensors_to_evaluate_names): + curShape = tensors_shape[ii] + f.write(tensors_to_evaluate_names[ii] + " ") + for dim in curShape: + f.write(str(dim) + " ") + f.write("\n") + + def save_graph_metadata(output_tensor, sess, feed_dict): - graph_def = sess.graph_def - transforms = [ - 'remove_nodes(op=Identity)', - 'strip_unused_nodes', - 'fold_batch_norms', - 'fold_constants(ignore_errors=true)' - ] - optimized_graph_def = TransformGraph(graph_def, [], [output_tensor.name], transforms) - with open('./graphDef.mtdata', 'w') as f: - f.write(str(optimized_graph_def)) - - # Save size information for tensors on which output depends - tensors_to_evaluate = [] - tensors_to_evaluate_names = [] - graph = sess.graph - for node in optimized_graph_def.node: - output_number = 0 - for cur_output in graph.get_operation_by_name(node.name).outputs: - tensors_to_evaluate.append(cur_output) - if output_number == 0: - tensor_name = node.name - else: - tensor_name = cur_output.name - tensors_to_evaluate_names.append(tensor_name) - output_number += 1 - tensors_evaluated = sess.run(tensors_to_evaluate, feed_dict) - tensors_shape = list(map(lambda x : x.shape, tensors_evaluated)) - - # Write size info in a file - with open('./sizeInfo.mtdata','w') as f: - for ii, curr in enumerate(tensors_to_evaluate_names): - curShape = tensors_shape[ii] - f.write(tensors_to_evaluate_names[ii] + ' ') - for dim in curShape: - f.write(str(dim)+' ') - f.write('\n') - - return optimized_graph_def + graph_def = sess.graph_def + transforms = [ + "remove_nodes(op=Identity)", + "strip_unused_nodes", + "fold_batch_norms", + "fold_constants(ignore_errors=true)", + ] + optimized_graph_def = TransformGraph( + graph_def, [], [output_tensor.name], transforms + ) + with open("./graphDef.mtdata", "w") as f: + f.write(str(optimized_graph_def)) + + # Save size information for tensors on which output depends + tensors_to_evaluate = [] + tensors_to_evaluate_names = [] + graph = sess.graph + for node in optimized_graph_def.node: + output_number = 0 + for cur_output in graph.get_operation_by_name(node.name).outputs: + tensors_to_evaluate.append(cur_output) + if output_number == 0: + tensor_name = node.name + else: + tensor_name = cur_output.name + tensors_to_evaluate_names.append(tensor_name) + output_number += 1 + tensors_evaluated = sess.run(tensors_to_evaluate, feed_dict) + tensors_shape = list(map(lambda x: x.shape, tensors_evaluated)) + + # Write size info in a file + with open("./sizeInfo.mtdata", "w") as f: + for ii, curr in enumerate(tensors_to_evaluate_names): + curShape = tensors_shape[ii] + f.write(tensors_to_evaluate_names[ii] + " ") + for dim in curShape: + f.write(str(dim) + " ") + f.write("\n") + + return optimized_graph_def + def updateWeightsForBN(optimized_graph_def, sess, feed_dict={}): - graph = sess.graph - graphDef = optimized_graph_def + graph = sess.graph + graphDef = optimized_graph_def + + for node in graphDef.node: + if node.op == "FusedBatchNorm" or node.op == "FusedBatchNormV3": + gamma = graph.get_operation_by_name(node.input[1]).outputs[0] + beta = graph.get_operation_by_name(node.input[2]).outputs[0] + mu = graph.get_operation_by_name(node.input[3]).outputs[0] + variance = graph.get_operation_by_name(node.input[4]).outputs[0] - for node in graphDef.node: - if (node.op == 'FusedBatchNorm' or node.op == 'FusedBatchNormV3'): - gamma = graph.get_operation_by_name(node.input[1]).outputs[0] - beta = graph.get_operation_by_name(node.input[2]).outputs[0] - mu = graph.get_operation_by_name(node.input[3]).outputs[0] - variance = graph.get_operation_by_name(node.input[4]).outputs[0] + epsilon = node.attr["epsilon"].f + rsigma = tf.rsqrt(variance + epsilon) - epsilon = node.attr['epsilon'].f - rsigma = tf.rsqrt(variance + epsilon) + sess.run(tf.assign(gamma, gamma * rsigma), feed_dict) + sess.run(tf.assign(beta, beta - gamma * mu), feed_dict) + sess.run(tf.assign(mu, tf.zeros(tf.shape(mu))), feed_dict) + sess.run( + tf.assign(variance, tf.fill(tf.shape(variance), 1 - epsilon)), feed_dict + ) - sess.run(tf.assign(gamma, gamma*rsigma), feed_dict) - sess.run(tf.assign(beta, beta - gamma*mu), feed_dict) - sess.run(tf.assign(mu, tf.zeros(tf.shape(mu))), feed_dict) - sess.run(tf.assign(variance, tf.fill(tf.shape(variance), 1-epsilon)), feed_dict) def dumpImageDataInt(imgData, filename, scalingFac, writeMode): - print("Dumping image data...") - with open(filename, writeMode) as ff: - for xx in numpy.nditer(imgData, order='C'): - ff.write(str(int(xx * (1<=0 and n1<=9 and n2>=0 and n2<=9 and n3>=0 and n3<=9): - self.__tensorBytes[byteArrIdx] = (n1)*64 + (n2)*8 + (n3) + n1 = ord(self.__tensorContentInput[tsCtnIdx + 1]) - ord("0") + n2 = ord(self.__tensorContentInput[tsCtnIdx + 2]) - ord("0") + n3 = ord(self.__tensorContentInput[tsCtnIdx + 3]) - ord("0") + if ( + n1 >= 0 + and n1 <= 9 + and n2 >= 0 + and n2 <= 9 + and n3 >= 0 + and n3 <= 9 + ): + self.__tensorBytes[byteArrIdx] = (n1) * 64 + (n2) * 8 + (n3) tsCtnIdx += 3 else: - self.__tensorBytes[byteArrIdx] = ord(self.__tensorContentInput[tsCtnIdx:tsCtnIdx+2].encode('latin-1').decode('unicode_escape')) - tsCtnIdx += 1 + self.__tensorBytes[byteArrIdx] = ord( + self.__tensorContentInput[tsCtnIdx : tsCtnIdx + 2] + .encode("latin-1") + .decode("unicode_escape") + ) + tsCtnIdx += 1 byteArrIdx += 1 tsCtnIdx += 1 @@ -230,45 +288,65 @@ def readFromFilePointer(self, fileP, cnt): cnt += 1 while line: tokens = line.split() - if (errIfTokensNotMinLen(tokens, 1, cnt, "Tensor")): return (False, cnt) + if errIfTokensNotMinLen(tokens, 1, cnt, "Tensor"): + return (False, cnt) curToken = tokens[0] - if (curToken == "}"): + if curToken == "}": return (True, cnt) - elif (curToken == "tensor_shape"): + elif curToken == "tensor_shape": sh = Shape() (noParseErr, cnt) = sh.readFromFilePointer(fileP, cnt) - if not(noParseErr): - print("Error in reading shape while parsing tensor at line =", cnt, file=sys.stderr) + if not (noParseErr): + print( + "Error in reading shape while parsing tensor at line =", + cnt, + file=sys.stderr, + ) return (False, cnt) self.__tensorShape = sh - if (len(self.__tensorShape.getDimRef()) == 0): + if len(self.__tensorShape.getDimRef()) == 0: self.__totalSize = 0 - elif (curToken == "dtype:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Tensor")): return (False, cnt) + elif curToken == "dtype:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Tensor"): + return (False, cnt) dtype = DataTypeEnum.Parse(tokens[1]) - if (dtype == DataTypeEnum.DT_INVALID): - print("Unknown dtype found while parsing Tensor at line =", cnt, file=sys.stderr) + if dtype == DataTypeEnum.DT_INVALID: + print( + "Unknown dtype found while parsing Tensor at line =", + cnt, + file=sys.stderr, + ) return (False, cnt) else: self.__dtype = dtype - elif (curToken == "tensor_content:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Tensor")): return (False, cnt) + elif curToken == "tensor_content:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Tensor"): + return (False, cnt) self.__tensorContentInput = tokens[1] self.__convToBytes() - elif (curToken == "float_val:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Tensor")): return (False, cnt) + elif curToken == "float_val:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Tensor"): + return (False, cnt) self.__valInput = float(tokens[1]) self.__convToBytes() - elif (curToken == "bool_val:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Tensor")): return (False, cnt) + elif curToken == "bool_val:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Tensor"): + return (False, cnt) self.__valInput = bool(tokens[1]) self.__convToBytes() - elif (curToken == "int_val:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Tensor")): return (False, cnt) + elif curToken == "int_val:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Tensor"): + return (False, cnt) self.__valInput = int(tokens[1]) self.__convToBytes() else: - print("Unknown token found while parsing Tensor at line =", cnt, ", token =", curToken, file=sys.stderr) + print( + "Unknown token found while parsing Tensor at line =", + cnt, + ", token =", + curToken, + file=sys.stderr, + ) return (False, cnt) line = fileP.readline() cnt += 1 @@ -281,7 +359,7 @@ def print(self): print("Content:", self.__tensorContentInput) print("ShapeRank:", self.__tensorShape.getRank()) print("TotalSizeBytes:", self.__totalSize) - if (self.__tensorBytes): + if self.__tensorBytes: print("ActualContentBytes:", self.__tensorBytes) else: print("ValArr:", self.__valArr) @@ -302,7 +380,7 @@ def getContentAsValArr(self): # else: # # Right now the CNN tensorflow benchmark i am dealing with only has int32 case when tensorContents are given as bytes. # # If in future, we encounter this case for float/bool, deal with it accordingly here - # # Plus, also from empirical observation, byteorder is little and its a signed value for ints. + # # Plus, also from empirical observation, byteorder is little and its a signed value for ints. # # Figure out for others when the time comes. # print(self.__dtype) # assert False @@ -314,35 +392,36 @@ def getContentAsValArr(self): # it += numOfBytesPerVal # self.__valArr = returnArr if self.__dtype == DataTypeEnum.DT_FLOAT: - dtype = numpy.dtype('0: - assert(all((type(x) is int) for x in self.__valIntLi)) + if len(self.__valIntLi) > 0: + assert all((type(x) is int) for x in self.__valIntLi) return self.__valIntLi + class Value: def __init__(self): self.__val = None @@ -398,117 +489,141 @@ def readFromFilePointer(self, fileP, cnt): cnt += 1 while line: tokens = line.split() - if (errIfTokensNotMinLen(tokens, 1, cnt, "Value")): return (False, cnt) + if errIfTokensNotMinLen(tokens, 1, cnt, "Value"): + return (False, cnt) curToken = tokens[0] - if (curToken == "}"): + if curToken == "}": return (True, cnt) - elif (curToken == "s:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Value")): return (False, cnt) + elif curToken == "s:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Value"): + return (False, cnt) self.__val = tokens[1][1:-1] - elif (curToken == "i:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Value")): return (False, cnt) + elif curToken == "i:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Value"): + return (False, cnt) self.__val = int(tokens[1]) - elif (curToken == "f:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Value")): return (False, cnt) + elif curToken == "f:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Value"): + return (False, cnt) self.__val = float(tokens[1]) - elif (curToken == "b:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Value")): return (False, cnt) + elif curToken == "b:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Value"): + return (False, cnt) self.__val = bool(tokens[1] == "true") - elif (curToken == "type:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "Value")): return (False, cnt) + elif curToken == "type:": + if errIfTokensNotMinLen(tokens, 2, cnt, "Value"): + return (False, cnt) dtype = DataTypeEnum.Parse(tokens[1]) - if (dtype == DataTypeEnum.DT_INVALID): - print("Invalid dtype found while parsing Value at line =", cnt, file=sys.stderr) + if dtype == DataTypeEnum.DT_INVALID: + print( + "Invalid dtype found while parsing Value at line =", + cnt, + file=sys.stderr, + ) return (False, cnt) else: self.__val = dtype - elif (curToken == "shape"): + elif curToken == "shape": sh = Shape() (noParseError, cnt) = sh.readFromFilePointer(fileP, cnt) - if (not(noParseError)): + if not (noParseError): print("Error in parsing Value at line =", cnt, file=sys.stderr) return (False, cnt) self.__val = sh - elif (curToken == "list"): + elif curToken == "list": mv = MultiValue() (noParseError, cnt) = mv.readFromFilePointer(fileP, cnt) - if (not(noParseError)): + if not (noParseError): print("Error in parsing Value at line =", cnt, file=sys.stderr) return (False, cnt) self.__val = mv - elif (curToken == "tensor"): + elif curToken == "tensor": ts = Tensor() (noParseError, cnt) = ts.readFromFilePointer(fileP, cnt) - if (not(noParseError)): + if not (noParseError): print("Error in parsing Value at line =", cnt, file=sys.stderr) return (False, cnt) self.__val = ts else: - print("Unknown token while parsing Value at line =", cnt, ", token =", curToken, file=sys.stderr) + print( + "Unknown token while parsing Value at line =", + cnt, + ", token =", + curToken, + file=sys.stderr, + ) return (False, cnt) line = fileP.readline() cnt += 1 return (False, cnt) def print(self): - if (type(self.__val) is str): print("s:", self.__val) - elif (type(self.__val) is int): print("i:", self.__val) - elif (type(self.__val) is float): print("f:", self.__val) - elif (type(self.__val) is bool): print("b:", self.__val) - elif (type(self.__val) is DataTypeEnum): print("Type:", self.__val) - elif (type(self.__val) is Shape): + if type(self.__val) is str: + print("s:", self.__val) + elif type(self.__val) is int: + print("i:", self.__val) + elif type(self.__val) is float: + print("f:", self.__val) + elif type(self.__val) is bool: + print("b:", self.__val) + elif type(self.__val) is DataTypeEnum: + print("Type:", self.__val) + elif type(self.__val) is Shape: print("Shape: ", end="") self.__val.print() - elif (type(self.__val) is Tensor): + elif type(self.__val) is Tensor: print("Tensor: ", end="") self.__val.print() - elif (type(self.__val) is MultiValue): + elif type(self.__val) is MultiValue: print("List: ", end="") self.__val.print() else: - assert(False) + assert False def getS(self): - assert(type(self.__val) is str) + assert type(self.__val) is str return self.__val def getI(self): - assert(type(self.__val) is int) + assert type(self.__val) is int return self.__val def getF(self): - assert(type(self.__val) is float) + assert type(self.__val) is float return self.__val def getB(self): - assert(type(self.__val) is bool) + assert type(self.__val) is bool return self.__val def getDataType(self): - assert(type(self.__val) is DataTypeEnum) + assert type(self.__val) is DataTypeEnum return self.__val def getShape(self): - assert(type(self.__val) is Shape) + assert type(self.__val) is Shape return self.__val def getTensor(self): - assert(type(self.__val) is Tensor) + assert type(self.__val) is Tensor return self.__val def getList(self): - assert(type(self.__val) is MultiValue) + assert type(self.__val) is MultiValue return self.__val + class Node: def __init__(self, op="", inputs=None, name=""): - self.__name = name #Name of node - self.__op = op #Name of operation carried out by node + self.__name = name # Name of node + self.__op = op # Name of operation carried out by node if inputs is None: - self.__inputs = [] #List of all inputs to the current node + self.__inputs = [] # List of all inputs to the current node else: self.__inputs = inputs - self.__attr = {} #Map of (attrName, Value) of all attributes for the current node + self.__attr = ( + {} + ) # Map of (attrName, Value) of all attributes for the current node def getName(self): return self.__name @@ -534,29 +649,49 @@ def readAttrFromFilePointer(self, fileP, cnt): keyStr = None while line: tokens = line.split() - if (errIfTokensNotMinLen(tokens, 1, cnt, "attr from node")): return (False, cnt) + if errIfTokensNotMinLen(tokens, 1, cnt, "attr from node"): + return (False, cnt) curToken = tokens[0] - if (curToken == "}"): + if curToken == "}": return (True, cnt) - elif (curToken == "key:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "attr from node")): return (False, cnt) - if (keyStr): - #keyStr is already non-None .. there is then probably some error - print("Too many keys found while parsing attr for node at line =", cnt, file=sys.stderr) + elif curToken == "key:": + if errIfTokensNotMinLen(tokens, 2, cnt, "attr from node"): + return (False, cnt) + if keyStr: + # keyStr is already non-None .. there is then probably some error + print( + "Too many keys found while parsing attr for node at line =", + cnt, + file=sys.stderr, + ) return (False, cnt) keyStr = tokens[1][1:-1] - elif (curToken == "value"): + elif curToken == "value": curVal = Value() (noParseError, cnt) = curVal.readFromFilePointer(fileP, cnt) - if not(noParseError): - print("Error while parsing value of attr for node at line =", cnt, file=sys.stderr) + if not (noParseError): + print( + "Error while parsing value of attr for node at line =", + cnt, + file=sys.stderr, + ) return (False, cnt) - if not(keyStr): - print("Value found - but no key found for attr in node at line =", cnt, file=sys.stderr) + if not (keyStr): + print( + "Value found - but no key found for attr in node at line =", + cnt, + file=sys.stderr, + ) return (False, cnt) self.__attr[keyStr] = curVal else: - print("Unrecognized token found while parsing attribute for node at line =", cnt, ", token =", curToken, file=sys.stderr) + print( + "Unrecognized token found while parsing attribute for node at line =", + cnt, + ", token =", + curToken, + file=sys.stderr, + ) return (False, cnt) line = fileP.readline() cnt += 1 @@ -567,36 +702,46 @@ def readFromFilePointer(self, fileP, cnt): cnt += 1 while line: tokens = line.split() - if (errIfTokensNotMinLen(tokens, 1, cnt, "node")): return (False, cnt) + if errIfTokensNotMinLen(tokens, 1, cnt, "node"): + return (False, cnt) curToken = tokens[0] - if (curToken == "}"): + if curToken == "}": return (True, cnt) - elif (curToken == "name:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt) + elif curToken == "name:": + if errIfTokensNotMinLen(tokens, 2, cnt, "node"): + return (False, cnt) self.__name = tokens[1][1:-1] - elif (curToken == "op:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt) + elif curToken == "op:": + if errIfTokensNotMinLen(tokens, 2, cnt, "node"): + return (False, cnt) self.__op = tokens[1][1:-1] - elif (curToken == "input:"): - if (errIfTokensNotMinLen(tokens, 2, cnt, "node")): return (False, cnt) + elif curToken == "input:": + if errIfTokensNotMinLen(tokens, 2, cnt, "node"): + return (False, cnt) input_name = tokens[1][1:-1] # Sometimes graph defs generated specify 0'th output explicitly whereas the node names do not # contain that. So we strip it if input_name.endswith(":0"): input_name = input_name[:-2] self.__inputs.append(input_name) - elif (curToken == "attr"): + elif curToken == "attr": (noParseError, cnt) = self.readAttrFromFilePointer(fileP, cnt) - if (not(noParseError)): + if not (noParseError): print("Error parsing node data at line =", cnt, file=sys.stderr) return (False, cnt) else: - print("Unrecognized token found while parsing node data at line =", cnt, ", token =", curToken, file=sys.stderr) + print( + "Unrecognized token found while parsing node data at line =", + cnt, + ", token =", + curToken, + file=sys.stderr, + ) return (False, cnt) line = fileP.readline() cnt += 1 return (False, cnt) - + def print(self): print("NODE::") print(self.__name, ",", self.__op) @@ -607,10 +752,13 @@ def print(self): print("Attr:", attrKey) attrVal.print() + class Graph: def __init__(self): - self.__Nodes = {} # Map of (op, Node) - self.__NodesLi = [] # Sequential list of nodes in the order in which its specified in graph_def. + self.__Nodes = {} # Map of (op, Node) + self.__NodesLi = ( + [] + ) # Sequential list of nodes in the order in which its specified in graph_def. def getAllNodes(self): return self.__Nodes @@ -626,29 +774,44 @@ def readFromFilePointer(self, fileP): cnt = 1 while line: tokens = line.split() - if (errIfTokensNotMinLen(tokens, 1, cnt, "graph")): return False + if errIfTokensNotMinLen(tokens, 1, cnt, "graph"): + return False curToken = tokens[0] - if (curToken == "node"): + if curToken == "node": curNode = Node() (noPaseError, cnt) = curNode.readFromFilePointer(fileP, cnt) - if (noPaseError): + if noPaseError: self.__Nodes[curNode.getName()] = curNode self.__NodesLi.append(curNode) else: - print("Error parsing graph dump for node at line =", cnt, file=sys.stderr) + print( + "Error parsing graph dump for node at line =", + cnt, + file=sys.stderr, + ) return False - elif (curToken == "}"): - #CurNode ended + elif curToken == "}": + # CurNode ended pass - elif (curToken == "versions" or curToken == "library"): - print("Versions/Library node found. Ignoring remainder graph. Line =", cnt, file=sys.stderr) + elif curToken == "versions" or curToken == "library": + print( + "Versions/Library node found. Ignoring remainder graph. Line =", + cnt, + file=sys.stderr, + ) print("Graph parsing successful.", file=sys.stderr) return True else: - print("Unrecognized token in graph dump at line =", cnt, ", token =", curToken, file=sys.stderr) + print( + "Unrecognized token in graph dump at line =", + cnt, + ", token =", + curToken, + file=sys.stderr, + ) return False line = fileP.readline() - cnt+=1 + cnt += 1 print("Graph parsing successful.") return True diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 545cd53d..a6fed45a 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,249 +20,279 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import os, sys + sys.path.append(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'SeeDot')) #Add SeeDot directory to path +sys.path.append( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "SeeDot") +) # Add SeeDot directory to path import Graph, AST.AST as AST, _pickle as pickle, os from TFNodesAST import TFNodesAST from AST.PrintAST import PrintAST from AST.MtdAST import MtdAST -def checkTFNodeNameForEq(curNodeOp:str, givenOp:str): - return (curNodeOp == "\"" + givenOp + "\"") + +def checkTFNodeNameForEq(curNodeOp: str, givenOp: str): + return curNodeOp == '"' + givenOp + '"' + def generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict): - curNodeOp = curNode.getOp() - ast = None - func = getattr(TFNodesAST, curNodeOp) - (assignedVarAST, curASTs) = func(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) - return (assignedVarAST, curASTs) + curNodeOp = curNode.getOp() + ast = None + func = getattr(TFNodesAST, curNodeOp) + (assignedVarAST, curASTs) = func( + graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict + ) + return (assignedVarAST, curASTs) + -#Takes the graph DS and outputs IR in SeeDot for the same +# Takes the graph DS and outputs IR in SeeDot for the same def generateIRCode(graph, extraInfoDict): - program = None - innerMostLetASTNode = None - dictNodeNameToOutVarStr = {} - outVarCt = 0 - outVarPrefix = "J" - mtdAST = MtdAST() - for curNode in graph.getAllNodesRef(): - for curInp in curNode.getInputsRef(): - assert(curInp in dictNodeNameToOutVarStr), "input={} expected as input for node={} but not yet processed".format(curInp, curNode.getName()) #Consequence of topological sorting of the TF graph - (assignedVarAST, curAsts) = generateASTForNode(graph, curNode, dictNodeNameToOutVarStr, extraInfoDict) - for outputName, curAst in curAsts.items(): - mtdForCurAST = {AST.ASTNode.mtdKeyTFOpName : curNode.getOp(), - AST.ASTNode.mtdKeyTFNodeName : outputName} - - if (curAst is None): - dictNodeNameToOutVarStr[outputName] = None - continue - curOutVarStr = outVarPrefix + str(outVarCt) - curOutVarAstNode = (assignedVarAST if assignedVarAST else AST.ID(curOutVarStr)) - if program: - assert(type(innerMostLetASTNode) is AST.Let) - newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode) - mtdAST.visit(newNode, mtdForCurAST) - innerMostLetASTNode.expr = newNode - innerMostLetASTNode = newNode - else: - innerMostLetASTNode = AST.Let(AST.ID(curOutVarStr), curAst, curOutVarAstNode) - mtdAST.visit(innerMostLetASTNode, mtdForCurAST) - innerMostLetASTNode.depth = 0 - program = innerMostLetASTNode - dictNodeNameToOutVarStr[outputName] = curOutVarStr - outVarCt += 1 - return (program, dictNodeNameToOutVarStr) + program = None + innerMostLetASTNode = None + dictNodeNameToOutVarStr = {} + outVarCt = 0 + outVarPrefix = "J" + mtdAST = MtdAST() + for curNode in graph.getAllNodesRef(): + for curInp in curNode.getInputsRef(): + assert ( + curInp in dictNodeNameToOutVarStr + ), "input={} expected as input for node={} but not yet processed".format( + curInp, curNode.getName() + ) # Consequence of topological sorting of the TF graph + (assignedVarAST, curAsts) = generateASTForNode( + graph, curNode, dictNodeNameToOutVarStr, extraInfoDict + ) + for outputName, curAst in curAsts.items(): + mtdForCurAST = { + AST.ASTNode.mtdKeyTFOpName: curNode.getOp(), + AST.ASTNode.mtdKeyTFNodeName: outputName, + } + + if curAst is None: + dictNodeNameToOutVarStr[outputName] = None + continue + curOutVarStr = outVarPrefix + str(outVarCt) + curOutVarAstNode = ( + assignedVarAST if assignedVarAST else AST.ID(curOutVarStr) + ) + if program: + assert type(innerMostLetASTNode) is AST.Let + newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode) + mtdAST.visit(newNode, mtdForCurAST) + innerMostLetASTNode.expr = newNode + innerMostLetASTNode = newNode + else: + innerMostLetASTNode = AST.Let( + AST.ID(curOutVarStr), curAst, curOutVarAstNode + ) + mtdAST.visit(innerMostLetASTNode, mtdForCurAST) + innerMostLetASTNode.depth = 0 + program = innerMostLetASTNode + dictNodeNameToOutVarStr[outputName] = curOutVarStr + outVarCt += 1 + return (program, dictNodeNameToOutVarStr) + def readSizeInfo(fileName): - allLines = None - with open(fileName) as f: - allLines = f.readlines() - sizeInfo = {} - for line in allLines: - tokens = line.split() - nodeName = tokens[0] - tokens = tokens[1:] - nodeOPSize = [] - if (not tokens): - nodeOPSize = [1] - else: - for curDimStr in tokens: - if (curDimStr == ''): continue - nodeOPSize.append(int(curDimStr)) - sizeInfo[nodeName] = nodeOPSize - return sizeInfo + allLines = None + with open(fileName) as f: + allLines = f.readlines() + sizeInfo = {} + for line in allLines: + tokens = line.split() + nodeName = tokens[0] + tokens = tokens[1:] + nodeOPSize = [] + if not tokens: + nodeOPSize = [1] + else: + for curDimStr in tokens: + if curDimStr == "": + continue + nodeOPSize.append(int(curDimStr)) + sizeInfo[nodeName] = nodeOPSize + return sizeInfo + # Since later on in the pipeline, the placeholder nodes which come up as cin statements # are to be excluded from the timing calculation, output all such PlaceHolder nodes together first. -# This doesn't violate the topological ordering because all such PlaceHolder nodes are leaf nodes +# This doesn't violate the topological ordering because all such PlaceHolder nodes are leaf nodes # in the graph. # This however extends live ranges of inputs and increases peak memory usage. # This also maintains the partial order between placeholder/variable nodes def prefixAllPlaceHolderNodes(graph): - allNodes = graph.getAllNodesRef() - placeHolderNodes = [] - variableNodes = [] - remNodes = [] - for curNode in allNodes: - if curNode.getOp() in ["Placeholder", "VariableV2"]: - assert(len(curNode.getInputsRef()) == 0) # Assert this is indeed a leaf node - placeHolderNodes.append(curNode) - else: - remNodes.append(curNode) - graph.setNodesList(placeHolderNodes + remNodes) + allNodes = graph.getAllNodesRef() + placeHolderNodes = [] + variableNodes = [] + remNodes = [] + for curNode in allNodes: + if curNode.getOp() in ["Placeholder", "VariableV2"]: + assert len(curNode.getInputsRef()) == 0 # Assert this is indeed a leaf node + placeHolderNodes.append(curNode) + else: + remNodes.append(curNode) + graph.setNodesList(placeHolderNodes + remNodes) + # List of Optimisations # 1. Split squared difference into (a-b)*(a-b) # 2. Reshape filter of depth separable convolution to convert it to a grouped convolution def simplifyGraph(graph, sizeInfo): - allNodes = graph.getAllNodesRef() - nodesMap = graph.getAllNodes() - newNodes = [] - inputsFixup = {} - for curNode in allNodes: - inputs = curNode.getInputsRef() - for i in range(len(inputs)): - if inputs[i] in inputsFixup: - inputs[i] = inputsFixup[inputs[i]] - if (curNode.getOp() == "SquaredDifference"): - sub = Graph.Node("Sub", inputs.copy(), curNode.getName() + "__sub") - mul = Graph.Node("Mul", [sub.getName(), sub.getName()], curNode.getName() + "__mul") - newNodes.append(sub) - newNodes.append(mul) - nodesMap[sub.getName()] = sub - nodesMap[mul.getName()] = mul - inputsFixup[curNode.getName()] = mul.getName() - nodesMap.pop(curNode.getName()) - elif (curNode.getOp() == "DepthwiseConv2dNative"): - filter_shape = sizeInfo[inputs[1]] - in_channels = filter_shape[2] - channel_multiplier = filter_shape[3] - output_channels = in_channels * channel_multiplier - # new filter shape = [FH, FW, 1, CI*CM] - new_filter_shape = filter_shape[0:2] + [1, output_channels] - reshape = Graph.Node("Reshape", [inputs[1]], curNode.getName() + "__reshape") - newNodes.append(reshape) - newNodes.append(curNode) - nodesMap[reshape.getName()] = reshape - inputs[1] = reshape.getName() - sizeInfo[reshape.getName()] = new_filter_shape - else: - newNodes.append(curNode) - graph.setNodesList(newNodes) + allNodes = graph.getAllNodesRef() + nodesMap = graph.getAllNodes() + newNodes = [] + inputsFixup = {} + for curNode in allNodes: + inputs = curNode.getInputsRef() + for i in range(len(inputs)): + if inputs[i] in inputsFixup: + inputs[i] = inputsFixup[inputs[i]] + if curNode.getOp() == "SquaredDifference": + sub = Graph.Node("Sub", inputs.copy(), curNode.getName() + "__sub") + mul = Graph.Node( + "Mul", [sub.getName(), sub.getName()], curNode.getName() + "__mul" + ) + newNodes.append(sub) + newNodes.append(mul) + nodesMap[sub.getName()] = sub + nodesMap[mul.getName()] = mul + inputsFixup[curNode.getName()] = mul.getName() + nodesMap.pop(curNode.getName()) + elif curNode.getOp() == "DepthwiseConv2dNative": + filter_shape = sizeInfo[inputs[1]] + in_channels = filter_shape[2] + channel_multiplier = filter_shape[3] + output_channels = in_channels * channel_multiplier + # new filter shape = [FH, FW, 1, CI*CM] + new_filter_shape = filter_shape[0:2] + [1, output_channels] + reshape = Graph.Node( + "Reshape", [inputs[1]], curNode.getName() + "__reshape" + ) + newNodes.append(reshape) + newNodes.append(curNode) + nodesMap[reshape.getName()] = reshape + inputs[1] = reshape.getName() + sizeInfo[reshape.getName()] = new_filter_shape + else: + newNodes.append(curNode) + graph.setNodesList(newNodes) + # We have to process all input nodes before output nodes. # However we cannot change the partial order of the placeholder and variable nodes. # The model weights are dumped from tensorflow in the original graphdef order and if # we don't adhere to that, different inputs will be read by the program. def arrange_input_before_output(graph): - allNodes = graph.getAllNodesRef() - visited = set() - already_sorted = True - for curNode in allNodes: - visited.add(curNode.getName()) - for inp in curNode.getInputsRef(): - if inp not in visited: - already_sorted = False - break - - # True almost all the time - if already_sorted: - return - - adjList = { i : [] for i in range(len(allNodes))} - position = { node.getName() : i for i,node in enumerate(allNodes)} - for i, curNode in enumerate(allNodes): - inputs = curNode.getInputsRef() - for inp in inputs: - adjList[position[inp]].append(i) - - # Additionally create edges between all placeholder and variable nodes - nodes_seen = [] - for i, curNode in reversed(list(enumerate(allNodes))): - if curNode.getOp() in ["Placeholder", "VariableV2"]: - adjList[i].extend(nodes_seen) - nodes_seen.append(i) - - no_nodes = len(allNodes) - visited = [False] * no_nodes - final_order = [] - - def topo_sort(v): - visited[v] = True - for i in adjList[v]: - if visited[i] == False: - topo_sort(i) - final_order.insert(0,v) - - for i in range(no_nodes): - if visited[i] == False: - topo_sort(i) - - assert len(final_order) == no_nodes, "Lost some nodes while sorting" - newNodes = [allNodes[i] for i in final_order] - graph.setNodesList(newNodes) - return + allNodes = graph.getAllNodesRef() + visited = set() + already_sorted = True + for curNode in allNodes: + visited.add(curNode.getName()) + for inp in curNode.getInputsRef(): + if inp not in visited: + already_sorted = False + break + + # True almost all the time + if already_sorted: + return + + adjList = {i: [] for i in range(len(allNodes))} + position = {node.getName(): i for i, node in enumerate(allNodes)} + for i, curNode in enumerate(allNodes): + inputs = curNode.getInputsRef() + for inp in inputs: + adjList[position[inp]].append(i) + + # Additionally create edges between all placeholder and variable nodes + nodes_seen = [] + for i, curNode in reversed(list(enumerate(allNodes))): + if curNode.getOp() in ["Placeholder", "VariableV2"]: + adjList[i].extend(nodes_seen) + nodes_seen.append(i) + + no_nodes = len(allNodes) + visited = [False] * no_nodes + final_order = [] + + def topo_sort(v): + visited[v] = True + for i in adjList[v]: + if visited[i] == False: + topo_sort(i) + final_order.insert(0, v) + + for i in range(no_nodes): + if visited[i] == False: + topo_sort(i) + + assert len(final_order) == no_nodes, "Lost some nodes while sorting" + newNodes = [allNodes[i] for i in final_order] + graph.setNodesList(newNodes) + return def process_tf_graph(filename): - sys.setrecursionlimit(10000) - - if os.path.isfile(filename): - folderName = os.path.dirname(filename) - elif os.path.isdir(filename): - folderName = filename - graphFileName = os.path.join(folderName, 'graphDef.mtdata') - graph = Graph.Graph() - with open(graphFileName) as file: - graph.readFromFilePointer(file) - - arrange_input_before_output(graph) - - # Read the sizeInfo also - sizeInfoFileName = os.path.join(folderName, 'sizeInfo.mtdata') - sizeInfo = readSizeInfo(sizeInfoFileName) - - # Tensorflow graph level optimisations - simplifyGraph(graph, sizeInfo) - # Place all PlaceHolder and variable nodes together at the beginning - prefixAllPlaceHolderNodes(graph) - - # Re-format the input names of nodes - for curNode in graph.getAllNodesRef(): - inputsRef = curNode.getInputsRef() - for i,curInput in enumerate(inputsRef): - if (curInput.startswith('^')): - # My hypothesis from empirical observation is that inputs which have '^' ahead of the node name - # denote control flow dependency and not data dependency. - # For all purposes for this compilation, control and data dependency is considered same. - # The reasoning being that everything is serial -- and graph execution is done in a - # a topological sort. - inputsRef[i] = curInput.split('^')[-1] - - # Create extra info dict - # Format : (sizeInfo) - extraInfoDict = {} - for k,v in sizeInfo.items(): - extraInfoDict[k] = (v,) - for curNode in graph.getAllNodesRef(): - if (curNode.getName() not in extraInfoDict): - extraInfoDict[curNode.getName()] = (None,) - - print("Generating code from TF graph def : ", graphFileName, " ...") - (program, dictNodeNameToOutVarStr) = generateIRCode(graph, extraInfoDict) - - print("SeeDot AST generation done. Pickling the AST.") - with open(os.path.join(folderName, 'astOutput.pkl'), 'wb') as f: - pickle.dump(program, f) + sys.setrecursionlimit(10000) + + if os.path.isfile(filename): + folderName = os.path.dirname(filename) + elif os.path.isdir(filename): + folderName = filename + graphFileName = os.path.join(folderName, "graphDef.mtdata") + graph = Graph.Graph() + with open(graphFileName) as file: + graph.readFromFilePointer(file) + + arrange_input_before_output(graph) + + # Read the sizeInfo also + sizeInfoFileName = os.path.join(folderName, "sizeInfo.mtdata") + sizeInfo = readSizeInfo(sizeInfoFileName) + + # Tensorflow graph level optimisations + simplifyGraph(graph, sizeInfo) + # Place all PlaceHolder and variable nodes together at the beginning + prefixAllPlaceHolderNodes(graph) + + # Re-format the input names of nodes + for curNode in graph.getAllNodesRef(): + inputsRef = curNode.getInputsRef() + for i, curInput in enumerate(inputsRef): + if curInput.startswith("^"): + # My hypothesis from empirical observation is that inputs which have '^' ahead of the node name + # denote control flow dependency and not data dependency. + # For all purposes for this compilation, control and data dependency is considered same. + # The reasoning being that everything is serial -- and graph execution is done in a + # a topological sort. + inputsRef[i] = curInput.split("^")[-1] + + # Create extra info dict + # Format : (sizeInfo) + extraInfoDict = {} + for k, v in sizeInfo.items(): + extraInfoDict[k] = (v,) + for curNode in graph.getAllNodesRef(): + if curNode.getName() not in extraInfoDict: + extraInfoDict[curNode.getName()] = (None,) + + print("Generating code from TF graph def : ", graphFileName, " ...") + (program, dictNodeNameToOutVarStr) = generateIRCode(graph, extraInfoDict) + + print("SeeDot AST generation done. Pickling the AST.") + with open(os.path.join(folderName, "astOutput.pkl"), "wb") as f: + pickle.dump(program, f) + if __name__ == "__main__": - if (len(sys.argv) < 2): - print("TF python file unspecified.", file=sys.stderr) - exit(1) + if len(sys.argv) < 2: + print("TF python file unspecified.", file=sys.stderr) + exit(1) - filename = sys.argv[1] - process_tf_graph(filename) \ No newline at end of file + filename = sys.argv[1] + process_tf_graph(filename) diff --git a/Athos/TFCompiler/TFNodesAST.py b/Athos/TFCompiler/TFNodesAST.py index e2faff7c..a3f34dba 100644 --- a/Athos/TFCompiler/TFNodesAST.py +++ b/Athos/TFCompiler/TFNodesAST.py @@ -1,4 +1,4 @@ -''' +""" Authors: Nishant Kumar. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import Graph import AST.AST as AST @@ -29,750 +29,1363 @@ # Contains code for each of the TF nodes encountered in the benchmarks. # For each such TF node, outputs the corresponding SeeDot AST. + class TFNodesAST: - class UninterpFuncCallNames(Enum): - ''' - NOTE : SeeDot when compiling uninterpreted function calls, adds a new declaration for each uninterpreted function call. - ''' - Input = auto() - CreateCopy = auto() - CreateIdentity = auto() - CreateTensor = auto() - CopyTensor = auto() - Const = auto() - Cast = auto() - TruncatedNormal = auto() - RandomUniform = auto() - Tile = auto() - MaxPool = auto() - Pack = auto() - Concat = auto() - ExpandDims = auto() - MaxPoolGrad = auto() - Conv2DBackpropInput = auto() - Conv2DBackpropFilter = auto() - AvgPool = auto() - Pad = auto() - Squeeze = auto() - - def getOperatorsIdx(token): - return AST.Operators.convSymbolToEnumValue(token) - - def MatMul(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - inp1Str = dictNodeNameToOutVarStr[inputsRef[0]] - inp2Str = dictNodeNameToOutVarStr[inputsRef[1]] - inp1AST = AST.ID(inp1Str) - inp2AST = AST.ID(inp2Str) - - attrMapRef = curNode.getAttrMapRef() - transposeABool = transposeBBool = False - if ("transpose_a" in attrMapRef): - transposeABool = attrMapRef["transpose_a"].getB() - if ("transpose_b" in attrMapRef): - transposeBBool = attrMapRef["transpose_b"].getB() - if (transposeABool): inp1AST = AST.Transp(inp1AST) - if (transposeBBool): inp2AST = AST.Transp(inp2AST) - return (None, { curNode.getName() : AST.BOp(inp1AST, TFNodesAST.getOperatorsIdx('*'), inp2AST)}) - - def Placeholder(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] - curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType() - assert(curNodeInputType is not Graph.DataTypeEnum.DT_INVALID) - - # NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code - # (in the scenario of secure inference), model is input by server and image by client. - # We assume in the following that the PlaceHolder op node represents the image and - # all model parameters are represented using Variable op nodes. - # Hence, in the call to AST.Input, we pass inputByParty=1. - - return (None, { curNode.getName() : AST.Input(curNodeShapeLi, curNodeInputType.name, isSecret=True, inputByParty=AST.Party.CLIENT)}) - - def Equal(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('=='), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )}) - - def Identity(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - #In SeeDot, J2=J1 creates a new reference for J1 -- so - # the corresponding code in Seedot cannot simply be J2 = J1. - # Instead create a new tensor first and then assign the old one to the new one. - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - - curNodeDataType = curNode.getAttrMapRef()["T"].getDataType() - assert(curNodeDataType is not Graph.DataTypeEnum.DT_INVALID) - - curNodeShape = extraNodeInfoDict[curNode.getName()][0] - retAST = AST.UninterpFuncCall(curNodeShape, - TFNodesAST.UninterpFuncCallNames.CreateIdentity.name, - [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])]) - return (None, { curNode.getName() : retAST}) - - def Add(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('+'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )}) - def AddV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('+'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )}) - - def Mul(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('.*'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )}) - - def Neg(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 1) - return (None, { curNode.getName() : AST.UOp(TFNodesAST.getOperatorsIdx('-'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]) - )}) - - def Sub(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('+'), - AST.UOp(TFNodesAST.getOperatorsIdx('-'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - ))}) - - def Floor(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 1) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('floor'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) - - def RealDiv(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('./'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )}) - - def FloorDiv(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - realDivAST = AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('./'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - ) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('floor'), realDivAST)}) - - def VariableV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - curNodeShapeLi = curNode.getAttrMapRef()["shape"].getShape().getDimRef()[:] - curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType() - - # NOTE : since this becomes an input node right now, i have also added to be prefixed at top in ProcessTFGraph::prefixAllPlaceHolderNodes() - # NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code - # (in the scenario of secure inference), model is input by server and image by client. - # We assume in the following that the PlaceHolder op node represents the image and - # all model parameters are represented using Variable op nodes. - # Hence, in the call to AST.Input, we pass inputByParty as SERVER. - return (None, { curNode.getName() : AST.Input(curNodeShapeLi, curNodeInputType.name, isSecret=True, inputByParty=AST.Party.SERVER)}) - - def Const(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - assert(len(curNode.getInputsRef()) == 0) - tensor = curNode.getAttrMapRef()["value"].getTensor() - curNodeDataType = curNode.getAttrMapRef()["dtype"].getDataType() - curNodeShape = tensor.getShapeRef()[:] #create a different copy to not change the original copy - - tensorConstantVal = tensor.getConstantVal() - if tensorConstantVal is not None: - # Use uinterpreted call of CreateTensor to create the tensor and fill it with a constant value - dataPassed = None - if curNodeDataType == Graph.DataTypeEnum.DT_INT32: - dataPassed = AST.Int(tensorConstantVal, 32, isSecret=False) - elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT: - dataPassed = AST.Float(tensorConstantVal, isSecret=False) - else: - assert False - - if (len(curNodeShape) == 0): - # This is a constant element - retAST = dataPassed - else: - retAST = AST.UninterpFuncCall(curNodeShape, - TFNodesAST.UninterpFuncCallNames.CreateTensor.name, - [dataPassed], - isSecret=False) - else: - # The tensor content is given as byte array. Extract val array from the byte array and create ast. - if curNodeDataType == Graph.DataTypeEnum.DT_INT32: - dataPassed = list(map(lambda x: AST.Int(x, 32, isSecret=False), tensor.getContentAsValArr()[:])) - elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT: - dataPassed = list(map(lambda x: AST.Float(x, isSecret=False), tensor.getContentAsValArr()[:])) - else: - assert False - retAST = AST.Decl(curNodeShape, None, dataPassed, isSecret=False) - return (None, { curNode.getName() : retAST}) - - def Relu(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('relu'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) - - def Tanh(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('tanh'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) - - def Sqrt(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('sqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) - - def Rsqrt(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('rsqrt'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) - - def Sigmoid(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('sigmoid'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) - - def Shape(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - return (None, { curNode.getName() : AST.Func(TFNodesAST.getOperatorsIdx('shape'), AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))}) - - def Cast(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 1) - sourceType = curNode.getAttrMapRef()["SrcT"].getDataType() - destType = curNode.getAttrMapRef()["DstT"].getDataType() - return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.Cast.name, - [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - AST.ID(sourceType.name), - AST.ID(destType.name) - ])}) - - def ZerosLike(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType() - assert(curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID) - retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.CreateTensor.name, - [AST.Int(0, isSecret=False)], - isSecret=False) - return (None, { curNode.getName() : retAST}) - - def Fill(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - curNodeOutputShape = extraNodeInfoDict[inputsRef[0]][0] - assert(len(curNodeOutputShape) == 1) #inputsRef[0] denotes a shape and should have a rank of 1 - - curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType() - assert(curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID) - - retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.CreateTensor.name, - [AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) ], - isSecret=False) - return (None, { curNode.getName() : retAST}) - - def Reshape(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - #assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.Reshape(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), extraNodeInfoDict[curNode.getName()][0], None)}) - - def helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD = None, FD = None, strideD = None): - if imgD: - assert(FD) - assert(strideD) - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 - if (paddingUsedStr == "SAME"): - # Reference for following: - # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn - totalPaddingH = totalPaddingW = totalPaddingD = 0 - - if (imgH % strideH == 0): - totalPaddingH = max(FH - strideH, 0) - else: - totalPaddingH = max(FH - (imgH % strideH), 0) - - if (imgW % strideW == 0): - totalPaddingW = max(FW - strideW, 0) - else: - totalPaddingW = max(FW - (imgW % strideW), 0) - - if imgD: - if (imgD % strideD == 0): - totalPaddingD = max(FD - strideD, 0) - else: - totalPaddingD = max(FD - (imgD % strideD), 0) - - zPadHLeft = totalPaddingH // 2 - zPadHRight = totalPaddingH - zPadHLeft - - zPadWLeft = totalPaddingW // 2 - zPadWRight = totalPaddingW - zPadWLeft - - zPadDLeft = totalPaddingD // 2 - zPadDRight = totalPaddingD - zPadDLeft - - elif (paddingUsedStr == "VALID"): - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0 - else: - zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 - - if imgD: - return [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] - else: - return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] - - def Conv2D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==2) - - stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() - assert(stridesUsed[0]==1 and stridesUsed[3]==1) - strideH = stridesUsed[1] - strideW = stridesUsed[2] - - inputShape = extraNodeInfoDict[inputsRef[0]][0] - imgH = inputShape[1] - imgW = inputShape[2] - - filterShape = extraNodeInfoDict[inputsRef[1]][0] - FH = filterShape[0] - FW = filterShape[1] - - paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() - - [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, - FH, FW, - strideH, strideW, - paddingUsedStr - ) - - options = {} - options[AST.PaddingKeysDict.FH] = FH - options[AST.PaddingKeysDict.FW] = FW - options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft - options[AST.PaddingKeysDict.zPadHRight] = zPadHRight - options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft - options[AST.PaddingKeysDict.zPadWRight] = zPadWRight - options[AST.PaddingKeysDict.strideH] = strideH - options[AST.PaddingKeysDict.strideW] = strideW - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('#'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - options)}) - - # A depthwise separable convolution is equivalent to a grouped convolution - # with no. of groups = the no. of input channels (G=CI) - # This however requires a reshape of the filter. - # Regular filter is [ FH, FW, CI, CM (channel_multiplier)] - # Doing depthwise conv results in CO = CI * CM - # Grouped conv expects [FH, FW, CI/G, (CO/G)*G] - # So we reshape to [FH, FW, 1, CI * CM] - def DepthwiseConv2dNative(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==2) - # Reshape of filter is done in simplifyGraph - img_shape = extraNodeInfoDict[inputsRef[0]][0] - in_channels = img_shape[3] #NHWC - groups = in_channels - _ , nodeToAST = TFNodesAST.Conv2D(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) - nodeToAST[curNode.getName()].options[AST.PaddingKeysDict.group] = groups - return (None, nodeToAST) - - def Conv3D(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==2) - - stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() - assert(stridesUsed[0]==1 and stridesUsed[4]==1) - strideD = stridesUsed[1] - strideH = stridesUsed[2] - strideW = stridesUsed[3] - - inputShape = extraNodeInfoDict[inputsRef[0]][0] - imgD = inputShape[1] - imgH = inputShape[2] - imgW = inputShape[3] - - filterShape = extraNodeInfoDict[inputsRef[1]][0] - FD = filterShape[0] - FH = filterShape[1] - FW = filterShape[2] - - paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() - - [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD ) - - options = {} - options[AST.PaddingKeysDict.FD] = FD - options[AST.PaddingKeysDict.FH] = FH - options[AST.PaddingKeysDict.FW] = FW - options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft - options[AST.PaddingKeysDict.zPadDRight] = zPadDRight - options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft - options[AST.PaddingKeysDict.zPadHRight] = zPadHRight - options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft - options[AST.PaddingKeysDict.zPadWRight] = zPadWRight - options[AST.PaddingKeysDict.strideD] = strideD - options[AST.PaddingKeysDict.strideH] = strideH - options[AST.PaddingKeysDict.strideW] = strideW - options[AST.PaddingKeysDict.ConvDim] = 3 - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('#'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - options)}) - - def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==3) #output_shape, filter, input - - stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() - assert(stridesUsed[0]==1 and stridesUsed[4]==1) - strideD = stridesUsed[1] - strideH = stridesUsed[2] - strideW = stridesUsed[3] - - filterShape = extraNodeInfoDict[inputsRef[1]][0] - FD = filterShape[0] - FH = filterShape[1] - FW = filterShape[2] - - inputShape = extraNodeInfoDict[inputsRef[2]][0] - inputD = inputShape[1] - inputH = inputShape[2] - inputW = inputShape[3] - - outputShape = extraNodeInfoDict[curNode.getName()][0] - outputD = outputShape[1] - outputH = outputShape[2] - outputW = outputShape[3] - - paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() - - # Important: Using outputH and outputW in the below is not an error! - # For convTranspose, the parameters passed in the node are of the conv of which this convTranspose is an inverse. - # Which is why the call to helper_findPadding makes sense. - # The zPads below are of the conv of which this convTranspose is an inverse. - [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(outputH, outputW, FH, FW, strideH, strideW, paddingUsedStr, imgD = outputD, FD = FD, strideD = strideD) - - options = {} - options[AST.PaddingKeysDict.FD] = FD - options[AST.PaddingKeysDict.FH] = FH - options[AST.PaddingKeysDict.FW] = FW - options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft - options[AST.PaddingKeysDict.zPadDRight] = zPadDRight - options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft - options[AST.PaddingKeysDict.zPadHRight] = zPadHRight - options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft - options[AST.PaddingKeysDict.zPadWRight] = zPadWRight - options[AST.PaddingKeysDict.strideD] = strideD - options[AST.PaddingKeysDict.strideH] = strideH - options[AST.PaddingKeysDict.strideW] = strideW - options[AST.PaddingKeysDict.ConvDim] = 3 - options[AST.PaddingKeysDict.outputImgD] = outputD - options[AST.PaddingKeysDict.outputImgH] = outputH - options[AST.PaddingKeysDict.outputImgW] = outputW - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), - TFNodesAST.getOperatorsIdx('#T'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - options)}) - - def helper_processPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict, typeOfPool:str): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef)==1) - - options = {} - - stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() - assert((stridesUsed[0] == 1) and (stridesUsed[3] == 1)) - strideH = stridesUsed[1] - strideW = stridesUsed[2] - - kSizeUsed = curNode.getAttrMapRef()["ksize"].getList().getILi() - assert((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)) - kSizeH = kSizeUsed[1] - kSizeW = kSizeUsed[2] - - inputShape = extraNodeInfoDict[inputsRef[0]][0] - imgH = inputShape[1] - imgW = inputShape[2] - - paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() - [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, - kSizeH, kSizeW, - strideH, strideW, - paddingUsedStr - ) - - poolType = None - if typeOfPool=='MAXPOOL': poolType = AST.Pool.PoolType.MaxPool - elif typeOfPool=='AVGPOOL': poolType = AST.Pool.PoolType.AvgPool - else: - print("Unknown type of pooling layer.", file=sys.stderr) - assert(False) - return (None, { curNode.getName() : AST.Pool(poolType, - AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - { - AST.PaddingKeysDict.FH: kSizeH, - AST.PaddingKeysDict.FW: kSizeW, - AST.PaddingKeysDict.zPadHLeft: zPadHLeft, - AST.PaddingKeysDict.zPadHRight: zPadHRight, - AST.PaddingKeysDict.zPadWLeft: zPadWLeft, - AST.PaddingKeysDict.zPadWRight: zPadWRight, - AST.PaddingKeysDict.strideH: strideH, - AST.PaddingKeysDict.strideW: strideW - } - )}) - - def MaxPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - return TFNodesAST.helper_processPool(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict, 'MAXPOOL') - - def AvgPool(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - return TFNodesAST.helper_processPool(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict, 'AVGPOOL') - - def ConcatV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - N = curNode.getAttrMapRef()["N"].getI() - assert(len(inputsRef) == N+1) #One extra for axis - #TODO : Since the axis of concat is constant, therefore, its known here - the input's sizes along that dim should be - # passed as input to the below function. - # For now hardcoding. - retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.Concat.name + str(N) + 'T', - list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)), - outputDiffInpDims=1 - ) - return (None, { curNode.getName() : retAST}) - - def ExpandDims(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - retAST = AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.ExpandDims.name, - list(map(lambda x : AST.ID(dictNodeNameToOutVarStr[x]), inputsRef))) - return (None, { curNode.getName() : retAST}) - - def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 3) - beginNode = graph.__getitem__(inputsRef[1]) - sizeNode = graph.__getitem__(inputsRef[2]) - assert beginNode.getAttrVal("value") is not None, "begin {} of Slice node {} has to be a constant".format(inputsRef[1], curNode.getName()) - assert sizeNode.getAttrVal("value") is not None, "size {} of Slice node {} has to be a constant".format(inputsRef[2], curNode.getName()) - begin = beginNode.getAttrVal("value").getTensor().getContentAsValArr() - size = sizeNode.getAttrVal("value").getTensor().getContentAsValArr() - assert begin is not None - assert size is not None - assert len(begin) == len(size) - subscriptRanges = [] - for i in range(0,len(size)): - subscriptRanges.append((begin[i], begin[i] + size[i] - 1)) - - return (None, { curNode.getName() : AST.Slice(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - subscriptRanges)}) - - def Tile(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.Tile.name, - [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])])}) - - def Sum(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - attrMapRef = curNode.getAttrMapRef() - assert(len(inputsRef) == 2) - keepdims = False - if ("keep_dims" in attrMapRef): - keepdims = attrMapRef["keep_dims"].getB() - - reductionAxesNodeName = inputsRef[1] - redAxesN = graph.__getitem__(reductionAxesNodeName) - redAxesT = redAxesN.getAttrVal("value").getTensor() - rank = redAxesT.getShapeRef().getRank() - if rank != 0: - reductionAxesList = redAxesT.getContentAsValArr() - else: - reductionAxesList = [redAxesT.getConstantVal()] - - curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] - return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - keepdims, - curNodeShapeLi, - TFNodesAST.getOperatorsIdx('+'), - reductionAxesList)}) - - def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - attrMapRef = curNode.getAttrMapRef() - assert(len(inputsRef) == 2) - keepdims = False - if ("keep_dims" in attrMapRef): - keepdims = attrMapRef["keep_dims"].getB() - - reductionAxesNodeName = inputsRef[1] - redAxesN = graph.__getitem__(reductionAxesNodeName) - redAxesT = redAxesN.getAttrVal("value").getTensor() - rank = redAxesT.getShapeRef().getRank() - if rank != 0: - reductionAxesList = redAxesT.getContentAsValArr() - else: - reductionAxesList = [redAxesT.getConstantVal()] - - curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] - return (None, { curNode.getName() : AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - keepdims, - curNodeShapeLi, - TFNodesAST.getOperatorsIdx('mean'), - reductionAxesList)}) - - def ArgMax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.ArgMax(extraNodeInfoDict[curNode.getName()][0], - AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - extraNodeInfoDict[inputsRef[0]][0])}) - - def NoOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - return (None, { curNode.getName() : None}) - - def Square(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 1) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('.*'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]) - )}) - - def Pad(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - # Mode refers to 'CONSTANT', 'REFLECT' or 'SYMMETRIC' - mode = 0 - if ("mode" in curNode.getAttrMapRef()): - mode = curNode.getAttrMapRef()["mode"].getI() - - constant_values = 0 - if ("constant_values" in curNode.getAttrMapRef()): - constant_values = curNode.getAttrMapRef()["constant_values"].getI() - - assert(mode == 0 and constant_values == 0) # For now to make life easy - deal with SYMMETRIC AND REFLECT when time comes - inputsRef = curNode.getInputsRef() - inputTensorShapeLi = extraNodeInfoDict[inputsRef[0]][0] - return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.Pad.name, - [ - AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - ], - outputDiffInpDims=1 - )}) - - def FusedBatchNorm(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - return (None, { curNode.getName() : AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), - )}) - def FusedBatchNormV3(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - return (None, { curNode.getName() : AST.FusedBatchNorm(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), - AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), - )}) - - def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - permNodeName = inputsRef[1] - # We need to fetch the tensor value of the perm Node - permNode = graph.__getitem__(permNodeName) - permTensor = permNode.getAttrVal("value").getTensor() - permList = permTensor.getContentAsValArr() - assert(permTensor.getDType().kind == "i") - assert(permTensor.getShapeRef().getRank() == 1) - return (None, { curNode.getName() : AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList)}) - - def Split(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - axisNodeName = inputsRef[0] # split_dim input. Has to be a constant. We don't support dynamic codegen yet - axisNode = graph.__getitem__(axisNodeName) - axisTensor = axisNode.getAttrVal("value").getTensor() - axis = axisTensor.getConstantVal() - numSplits = curNode.getAttrVal("num_split").getI() - inputTensorShape = extraNodeInfoDict[inputsRef[1]][0] - assert(axis < len(inputTensorShape)) - assert(inputTensorShape[axis] % numSplits == 0) #Should perfectly split - sizeAlongSplitDim = int(inputTensorShape[axis]/numSplits) - outputAsts = {} - for i in range(0, numSplits): - output_name = curNode.getName() - if i != 0: - output_name += ":" + str(i) - subscriptRanges = [] - for j in range(0, len(inputTensorShape)): - start = 0 - end = inputTensorShape[j] - 1 - if j == axis: - start = i*sizeAlongSplitDim - end = start + sizeAlongSplitDim - 1 - subscriptRanges.append((start,end)) - outputAsts[output_name] = AST.Slice(AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), subscriptRanges) - return (None, outputAsts) - - def Squeeze(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - inputTensorShape = extraNodeInfoDict[inputsRef[0]][0] - inputTensorRank = len(inputTensorShape) - - squeezeDims = curNode.getAttrMapRef()["squeeze_dims"].getList().getILi() - squeezeDimsRank = len(squeezeDims) - - return (None, { curNode.getName() : AST.UninterpFuncCall(extraNodeInfoDict[curNode.getName()][0], - TFNodesAST.UninterpFuncCallNames.Squeeze.name, - list(map(lambda x : AST.Int(x, 32, isSecret=False), squeezeDims)) + - [ - AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]) - ] - )}) - - def BiasAdd(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - assert(len(inputsRef) == 2) - return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), - TFNodesAST.getOperatorsIdx('+'), - AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]) - )}) - - def ReadVariableOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - return (None, { curNode.getName() : AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}) - - def Softmax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - return (None, { curNode.getName() : AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}) - - def StopGradient(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - inputsRef = curNode.getInputsRef() - return (None, { curNode.getName() : AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}) - - def VarHandleOp(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict): - return TFNodesAST.VariableV2(graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict) + class UninterpFuncCallNames(Enum): + """ + NOTE : SeeDot when compiling uninterpreted function calls, adds a new declaration for each uninterpreted function call. + """ + + Input = auto() + CreateCopy = auto() + CreateIdentity = auto() + CreateTensor = auto() + CopyTensor = auto() + Const = auto() + Cast = auto() + TruncatedNormal = auto() + RandomUniform = auto() + Tile = auto() + MaxPool = auto() + Pack = auto() + Concat = auto() + ExpandDims = auto() + MaxPoolGrad = auto() + Conv2DBackpropInput = auto() + Conv2DBackpropFilter = auto() + AvgPool = auto() + Pad = auto() + Squeeze = auto() + + def getOperatorsIdx(token): + return AST.Operators.convSymbolToEnumValue(token) + + def MatMul( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + inp1Str = dictNodeNameToOutVarStr[inputsRef[0]] + inp2Str = dictNodeNameToOutVarStr[inputsRef[1]] + inp1AST = AST.ID(inp1Str) + inp2AST = AST.ID(inp2Str) + + attrMapRef = curNode.getAttrMapRef() + transposeABool = transposeBBool = False + if "transpose_a" in attrMapRef: + transposeABool = attrMapRef["transpose_a"].getB() + if "transpose_b" in attrMapRef: + transposeBBool = attrMapRef["transpose_b"].getB() + if transposeABool: + inp1AST = AST.Transp(inp1AST) + if transposeBBool: + inp2AST = AST.Transp(inp2AST) + return ( + None, + { + curNode.getName(): AST.BOp( + inp1AST, TFNodesAST.getOperatorsIdx("*"), inp2AST + ) + }, + ) + + def Placeholder( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] + curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType() + assert curNodeInputType is not Graph.DataTypeEnum.DT_INVALID + + # NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code + # (in the scenario of secure inference), model is input by server and image by client. + # We assume in the following that the PlaceHolder op node represents the image and + # all model parameters are represented using Variable op nodes. + # Hence, in the call to AST.Input, we pass inputByParty=1. + + return ( + None, + { + curNode.getName(): AST.Input( + curNodeShapeLi, + curNodeInputType.name, + isSecret=True, + inputByParty=AST.Party.CLIENT, + ) + }, + ) + + def Equal( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("=="), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ) + }, + ) + + def Identity( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + # In SeeDot, J2=J1 creates a new reference for J1 -- so + # the corresponding code in Seedot cannot simply be J2 = J1. + # Instead create a new tensor first and then assign the old one to the new one. + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + + curNodeDataType = curNode.getAttrMapRef()["T"].getDataType() + assert curNodeDataType is not Graph.DataTypeEnum.DT_INVALID + + curNodeShape = extraNodeInfoDict[curNode.getName()][0] + retAST = AST.UninterpFuncCall( + curNodeShape, + TFNodesAST.UninterpFuncCallNames.CreateIdentity.name, + [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])], + ) + return (None, {curNode.getName(): retAST}) + + def Add( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("+"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ) + }, + ) + + def AddV2( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("+"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ) + }, + ) + + def Mul( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx(".*"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ) + }, + ) + + def Neg( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.UOp( + TFNodesAST.getOperatorsIdx("-"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Sub( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("+"), + AST.UOp( + TFNodesAST.getOperatorsIdx("-"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ), + ) + }, + ) + + def Floor( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("floor"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def RealDiv( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("./"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ) + }, + ) + + def FloorDiv( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + realDivAST = AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("./"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ) + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("floor"), realDivAST + ) + }, + ) + + def VariableV2( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + curNodeShapeLi = curNode.getAttrMapRef()["shape"].getShape().getDimRef()[:] + curNodeInputType = curNode.getAttrMapRef()["dtype"].getDataType() + + # NOTE : since this becomes an input node right now, i have also added to be prefixed at top in ProcessTFGraph::prefixAllPlaceHolderNodes() + # NOTE: There has to be some way for Athos to differentiate model from image, since in the compiled code + # (in the scenario of secure inference), model is input by server and image by client. + # We assume in the following that the PlaceHolder op node represents the image and + # all model parameters are represented using Variable op nodes. + # Hence, in the call to AST.Input, we pass inputByParty as SERVER. + return ( + None, + { + curNode.getName(): AST.Input( + curNodeShapeLi, + curNodeInputType.name, + isSecret=True, + inputByParty=AST.Party.SERVER, + ) + }, + ) + + def Const( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + assert len(curNode.getInputsRef()) == 0 + tensor = curNode.getAttrMapRef()["value"].getTensor() + curNodeDataType = curNode.getAttrMapRef()["dtype"].getDataType() + curNodeShape = tensor.getShapeRef()[ + : + ] # create a different copy to not change the original copy + + tensorConstantVal = tensor.getConstantVal() + if tensorConstantVal is not None: + # Use uinterpreted call of CreateTensor to create the tensor and fill it with a constant value + dataPassed = None + if curNodeDataType == Graph.DataTypeEnum.DT_INT32: + dataPassed = AST.Int(tensorConstantVal, 32, isSecret=False) + elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT: + dataPassed = AST.Float(tensorConstantVal, isSecret=False) + else: + assert False + + if len(curNodeShape) == 0: + # This is a constant element + retAST = dataPassed + else: + retAST = AST.UninterpFuncCall( + curNodeShape, + TFNodesAST.UninterpFuncCallNames.CreateTensor.name, + [dataPassed], + isSecret=False, + ) + else: + # The tensor content is given as byte array. Extract val array from the byte array and create ast. + if curNodeDataType == Graph.DataTypeEnum.DT_INT32: + dataPassed = list( + map( + lambda x: AST.Int(x, 32, isSecret=False), + tensor.getContentAsValArr()[:], + ) + ) + elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT: + dataPassed = list( + map( + lambda x: AST.Float(x, isSecret=False), + tensor.getContentAsValArr()[:], + ) + ) + else: + assert False + retAST = AST.Decl(curNodeShape, None, dataPassed, isSecret=False) + return (None, {curNode.getName(): retAST}) + + def Relu( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("relu"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Tanh( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("tanh"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Sqrt( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("sqrt"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Rsqrt( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("rsqrt"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Sigmoid( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("sigmoid"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Shape( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.Func( + TFNodesAST.getOperatorsIdx("shape"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Cast( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + sourceType = curNode.getAttrMapRef()["SrcT"].getDataType() + destType = curNode.getAttrMapRef()["DstT"].getDataType() + return ( + None, + { + curNode.getName(): AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.Cast.name, + [ + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + AST.ID(sourceType.name), + AST.ID(destType.name), + ], + ) + }, + ) + + def ZerosLike( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType() + assert curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID + retAST = AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.CreateTensor.name, + [AST.Int(0, isSecret=False)], + isSecret=False, + ) + return (None, {curNode.getName(): retAST}) + + def Fill( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + curNodeOutputShape = extraNodeInfoDict[inputsRef[0]][0] + assert ( + len(curNodeOutputShape) == 1 + ) # inputsRef[0] denotes a shape and should have a rank of 1 + + curNodeOutputType = curNode.getAttrMapRef()["T"].getDataType() + assert curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID + + retAST = AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.CreateTensor.name, + [AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])], + isSecret=False, + ) + return (None, {curNode.getName(): retAST}) + + def Reshape( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + # assert(len(inputsRef) == 2) + return ( + None, + { + curNode.getName(): AST.Reshape( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + extraNodeInfoDict[curNode.getName()][0], + None, + ) + }, + ) + + def helper_findPadding( + imgH, + imgW, + FH, + FW, + strideH, + strideW, + paddingUsedStr, + imgD=None, + FD=None, + strideD=None, + ): + if imgD: + assert FD + assert strideD + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 + if paddingUsedStr == "SAME": + # Reference for following: + # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn + totalPaddingH = totalPaddingW = totalPaddingD = 0 + + if imgH % strideH == 0: + totalPaddingH = max(FH - strideH, 0) + else: + totalPaddingH = max(FH - (imgH % strideH), 0) + + if imgW % strideW == 0: + totalPaddingW = max(FW - strideW, 0) + else: + totalPaddingW = max(FW - (imgW % strideW), 0) + + if imgD: + if imgD % strideD == 0: + totalPaddingD = max(FD - strideD, 0) + else: + totalPaddingD = max(FD - (imgD % strideD), 0) + + zPadHLeft = totalPaddingH // 2 + zPadHRight = totalPaddingH - zPadHLeft + + zPadWLeft = totalPaddingW // 2 + zPadWRight = totalPaddingW - zPadWLeft + + zPadDLeft = totalPaddingD // 2 + zPadDRight = totalPaddingD - zPadDLeft + + elif paddingUsedStr == "VALID": + zPadHLeft = zPadHRight = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = 0 + else: + zPadHLeft = ( + zPadHRight + ) = zPadWLeft = zPadWRight = zPadDLeft = zPadDRight = -1 + + if imgD: + return [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] + else: + return [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] + + def Conv2D( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() + assert stridesUsed[0] == 1 and stridesUsed[3] == 1 + strideH = stridesUsed[1] + strideW = stridesUsed[2] + + inputShape = extraNodeInfoDict[inputsRef[0]][0] + imgH = inputShape[1] + imgW = inputShape[2] + + filterShape = extraNodeInfoDict[inputsRef[1]][0] + FH = filterShape[0] + FW = filterShape[1] + + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() + + [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding( + imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr + ) + + options = {} + options[AST.PaddingKeysDict.FH] = FH + options[AST.PaddingKeysDict.FW] = FW + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideH] = strideH + options[AST.PaddingKeysDict.strideW] = strideW + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("#"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + options, + ) + }, + ) + + # A depthwise separable convolution is equivalent to a grouped convolution + # with no. of groups = the no. of input channels (G=CI) + # This however requires a reshape of the filter. + # Regular filter is [ FH, FW, CI, CM (channel_multiplier)] + # Doing depthwise conv results in CO = CI * CM + # Grouped conv expects [FH, FW, CI/G, (CO/G)*G] + # So we reshape to [FH, FW, 1, CI * CM] + def DepthwiseConv2dNative( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + # Reshape of filter is done in simplifyGraph + img_shape = extraNodeInfoDict[inputsRef[0]][0] + in_channels = img_shape[3] # NHWC + groups = in_channels + _, nodeToAST = TFNodesAST.Conv2D( + graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict + ) + nodeToAST[curNode.getName()].options[AST.PaddingKeysDict.group] = groups + return (None, nodeToAST) + + def Conv3D( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() + assert stridesUsed[0] == 1 and stridesUsed[4] == 1 + strideD = stridesUsed[1] + strideH = stridesUsed[2] + strideW = stridesUsed[3] + + inputShape = extraNodeInfoDict[inputsRef[0]][0] + imgD = inputShape[1] + imgH = inputShape[2] + imgW = inputShape[3] + + filterShape = extraNodeInfoDict[inputsRef[1]][0] + FD = filterShape[0] + FH = filterShape[1] + FW = filterShape[2] + + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() + + [ + zPadDLeft, + zPadDRight, + zPadHLeft, + zPadHRight, + zPadWLeft, + zPadWRight, + ] = TFNodesAST.helper_findPadding( + imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD + ) + + options = {} + options[AST.PaddingKeysDict.FD] = FD + options[AST.PaddingKeysDict.FH] = FH + options[AST.PaddingKeysDict.FW] = FW + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = strideD + options[AST.PaddingKeysDict.strideH] = strideH + options[AST.PaddingKeysDict.strideW] = strideW + options[AST.PaddingKeysDict.ConvDim] = 3 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("#"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + options, + ) + }, + ) + + def Conv3DBackpropInputV2( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 3 # output_shape, filter, input + + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() + assert stridesUsed[0] == 1 and stridesUsed[4] == 1 + strideD = stridesUsed[1] + strideH = stridesUsed[2] + strideW = stridesUsed[3] + + filterShape = extraNodeInfoDict[inputsRef[1]][0] + FD = filterShape[0] + FH = filterShape[1] + FW = filterShape[2] + + inputShape = extraNodeInfoDict[inputsRef[2]][0] + inputD = inputShape[1] + inputH = inputShape[2] + inputW = inputShape[3] + + outputShape = extraNodeInfoDict[curNode.getName()][0] + outputD = outputShape[1] + outputH = outputShape[2] + outputW = outputShape[3] + + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() + + # Important: Using outputH and outputW in the below is not an error! + # For convTranspose, the parameters passed in the node are of the conv of which this convTranspose is an inverse. + # Which is why the call to helper_findPadding makes sense. + # The zPads below are of the conv of which this convTranspose is an inverse. + [ + zPadDLeft, + zPadDRight, + zPadHLeft, + zPadHRight, + zPadWLeft, + zPadWRight, + ] = TFNodesAST.helper_findPadding( + outputH, + outputW, + FH, + FW, + strideH, + strideW, + paddingUsedStr, + imgD=outputD, + FD=FD, + strideD=strideD, + ) + + options = {} + options[AST.PaddingKeysDict.FD] = FD + options[AST.PaddingKeysDict.FH] = FH + options[AST.PaddingKeysDict.FW] = FW + options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft + options[AST.PaddingKeysDict.zPadDRight] = zPadDRight + options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft + options[AST.PaddingKeysDict.zPadHRight] = zPadHRight + options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft + options[AST.PaddingKeysDict.zPadWRight] = zPadWRight + options[AST.PaddingKeysDict.strideD] = strideD + options[AST.PaddingKeysDict.strideH] = strideH + options[AST.PaddingKeysDict.strideW] = strideW + options[AST.PaddingKeysDict.ConvDim] = 3 + options[AST.PaddingKeysDict.outputImgD] = outputD + options[AST.PaddingKeysDict.outputImgH] = outputH + options[AST.PaddingKeysDict.outputImgW] = outputW + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), + TFNodesAST.getOperatorsIdx("#T"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + options, + ) + }, + ) + + def helper_processPool( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + typeOfPool: str, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + + options = {} + + stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi() + assert (stridesUsed[0] == 1) and (stridesUsed[3] == 1) + strideH = stridesUsed[1] + strideW = stridesUsed[2] + + kSizeUsed = curNode.getAttrMapRef()["ksize"].getList().getILi() + assert (kSizeUsed[0] == 1) and (kSizeUsed[3] == 1) + kSizeH = kSizeUsed[1] + kSizeW = kSizeUsed[2] + + inputShape = extraNodeInfoDict[inputsRef[0]][0] + imgH = inputShape[1] + imgW = inputShape[2] + + paddingUsedStr = curNode.getAttrMapRef()["padding"].getS() + [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding( + imgH, imgW, kSizeH, kSizeW, strideH, strideW, paddingUsedStr + ) + + poolType = None + if typeOfPool == "MAXPOOL": + poolType = AST.Pool.PoolType.MaxPool + elif typeOfPool == "AVGPOOL": + poolType = AST.Pool.PoolType.AvgPool + else: + print("Unknown type of pooling layer.", file=sys.stderr) + assert False + return ( + None, + { + curNode.getName(): AST.Pool( + poolType, + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + { + AST.PaddingKeysDict.FH: kSizeH, + AST.PaddingKeysDict.FW: kSizeW, + AST.PaddingKeysDict.zPadHLeft: zPadHLeft, + AST.PaddingKeysDict.zPadHRight: zPadHRight, + AST.PaddingKeysDict.zPadWLeft: zPadWLeft, + AST.PaddingKeysDict.zPadWRight: zPadWRight, + AST.PaddingKeysDict.strideH: strideH, + AST.PaddingKeysDict.strideW: strideW, + }, + ) + }, + ) + + def MaxPool( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + return TFNodesAST.helper_processPool( + graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict, "MAXPOOL" + ) + + def AvgPool( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + return TFNodesAST.helper_processPool( + graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict, "AVGPOOL" + ) + + def ConcatV2( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + N = curNode.getAttrMapRef()["N"].getI() + assert len(inputsRef) == N + 1 # One extra for axis + # TODO : Since the axis of concat is constant, therefore, its known here - the input's sizes along that dim should be + # passed as input to the below function. + # For now hardcoding. + retAST = AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.Concat.name + str(N) + "T", + list(map(lambda x: AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)), + outputDiffInpDims=1, + ) + return (None, {curNode.getName(): retAST}) + + def ExpandDims( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + retAST = AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.ExpandDims.name, + list(map(lambda x: AST.ID(dictNodeNameToOutVarStr[x]), inputsRef)), + ) + return (None, {curNode.getName(): retAST}) + + def Slice( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 3 + beginNode = graph.__getitem__(inputsRef[1]) + sizeNode = graph.__getitem__(inputsRef[2]) + assert ( + beginNode.getAttrVal("value") is not None + ), "begin {} of Slice node {} has to be a constant".format( + inputsRef[1], curNode.getName() + ) + assert ( + sizeNode.getAttrVal("value") is not None + ), "size {} of Slice node {} has to be a constant".format( + inputsRef[2], curNode.getName() + ) + begin = beginNode.getAttrVal("value").getTensor().getContentAsValArr() + size = sizeNode.getAttrVal("value").getTensor().getContentAsValArr() + assert begin is not None + assert size is not None + assert len(begin) == len(size) + subscriptRanges = [] + for i in range(0, len(size)): + subscriptRanges.append((begin[i], begin[i] + size[i] - 1)) + + return ( + None, + { + curNode.getName(): AST.Slice( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), subscriptRanges + ) + }, + ) + + def Tile( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.Tile.name, + [ + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ], + ) + }, + ) + + def Sum( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + attrMapRef = curNode.getAttrMapRef() + assert len(inputsRef) == 2 + keepdims = False + if "keep_dims" in attrMapRef: + keepdims = attrMapRef["keep_dims"].getB() + + reductionAxesNodeName = inputsRef[1] + redAxesN = graph.__getitem__(reductionAxesNodeName) + redAxesT = redAxesN.getAttrVal("value").getTensor() + rank = redAxesT.getShapeRef().getRank() + if rank != 0: + reductionAxesList = redAxesT.getContentAsValArr() + else: + reductionAxesList = [redAxesT.getConstantVal()] + + curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] + return ( + None, + { + curNode.getName(): AST.Reduce( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + keepdims, + curNodeShapeLi, + TFNodesAST.getOperatorsIdx("+"), + reductionAxesList, + ) + }, + ) + + def Mean( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + attrMapRef = curNode.getAttrMapRef() + assert len(inputsRef) == 2 + keepdims = False + if "keep_dims" in attrMapRef: + keepdims = attrMapRef["keep_dims"].getB() + + reductionAxesNodeName = inputsRef[1] + redAxesN = graph.__getitem__(reductionAxesNodeName) + redAxesT = redAxesN.getAttrVal("value").getTensor() + rank = redAxesT.getShapeRef().getRank() + if rank != 0: + reductionAxesList = redAxesT.getContentAsValArr() + else: + reductionAxesList = [redAxesT.getConstantVal()] + + curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0] + return ( + None, + { + curNode.getName(): AST.Reduce( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + keepdims, + curNodeShapeLi, + TFNodesAST.getOperatorsIdx("mean"), + reductionAxesList, + ) + }, + ) + + def ArgMax( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.ArgMax( + extraNodeInfoDict[curNode.getName()][0], + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + extraNodeInfoDict[inputsRef[0]][0], + ) + }, + ) + + def NoOp( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + return (None, {curNode.getName(): None}) + + def Square( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 1 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx(".*"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + ) + }, + ) + + def Pad( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + # Mode refers to 'CONSTANT', 'REFLECT' or 'SYMMETRIC' + mode = 0 + if "mode" in curNode.getAttrMapRef(): + mode = curNode.getAttrMapRef()["mode"].getI() + + constant_values = 0 + if "constant_values" in curNode.getAttrMapRef(): + constant_values = curNode.getAttrMapRef()["constant_values"].getI() + + assert ( + mode == 0 and constant_values == 0 + ) # For now to make life easy - deal with SYMMETRIC AND REFLECT when time comes + inputsRef = curNode.getInputsRef() + inputTensorShapeLi = extraNodeInfoDict[inputsRef[0]][0] + return ( + None, + { + curNode.getName(): AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.Pad.name, + [ + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ], + outputDiffInpDims=1, + ) + }, + ) + + def FusedBatchNorm( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + return ( + None, + { + curNode.getName(): AST.FusedBatchNorm( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), + ) + }, + ) + + def FusedBatchNormV3( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + return ( + None, + { + curNode.getName(): AST.FusedBatchNorm( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]), + ) + }, + ) + + def Transpose( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + permNodeName = inputsRef[1] + # We need to fetch the tensor value of the perm Node + permNode = graph.__getitem__(permNodeName) + permTensor = permNode.getAttrVal("value").getTensor() + permList = permTensor.getContentAsValArr() + assert permTensor.getDType().kind == "i" + assert permTensor.getShapeRef().getRank() == 1 + return ( + None, + { + curNode.getName(): AST.Transpose( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList + ) + }, + ) + + def Split( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + axisNodeName = inputsRef[ + 0 + ] # split_dim input. Has to be a constant. We don't support dynamic codegen yet + axisNode = graph.__getitem__(axisNodeName) + axisTensor = axisNode.getAttrVal("value").getTensor() + axis = axisTensor.getConstantVal() + numSplits = curNode.getAttrVal("num_split").getI() + inputTensorShape = extraNodeInfoDict[inputsRef[1]][0] + assert axis < len(inputTensorShape) + assert inputTensorShape[axis] % numSplits == 0 # Should perfectly split + sizeAlongSplitDim = int(inputTensorShape[axis] / numSplits) + outputAsts = {} + for i in range(0, numSplits): + output_name = curNode.getName() + if i != 0: + output_name += ":" + str(i) + subscriptRanges = [] + for j in range(0, len(inputTensorShape)): + start = 0 + end = inputTensorShape[j] - 1 + if j == axis: + start = i * sizeAlongSplitDim + end = start + sizeAlongSplitDim - 1 + subscriptRanges.append((start, end)) + outputAsts[output_name] = AST.Slice( + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), subscriptRanges + ) + return (None, outputAsts) + + def Squeeze( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + inputTensorShape = extraNodeInfoDict[inputsRef[0]][0] + inputTensorRank = len(inputTensorShape) + + squeezeDims = curNode.getAttrMapRef()["squeeze_dims"].getList().getILi() + squeezeDimsRank = len(squeezeDims) + + return ( + None, + { + curNode.getName(): AST.UninterpFuncCall( + extraNodeInfoDict[curNode.getName()][0], + TFNodesAST.UninterpFuncCallNames.Squeeze.name, + list(map(lambda x: AST.Int(x, 32, isSecret=False), squeezeDims)) + + [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])], + ) + }, + ) + + def BiasAdd( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + assert len(inputsRef) == 2 + return ( + None, + { + curNode.getName(): AST.BOp( + AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), + TFNodesAST.getOperatorsIdx("+"), + AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]), + ) + }, + ) + + def ReadVariableOp( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + return ( + None, + {curNode.getName(): AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}, + ) + + def Softmax( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + return ( + None, + {curNode.getName(): AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}, + ) + + def StopGradient( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + inputsRef = curNode.getInputsRef() + return ( + None, + {curNode.getName(): AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])}, + ) + + def VarHandleOp( + graph: Graph.Graph, + curNode: Graph.Node, + dictNodeNameToOutVarStr: dict, + extraNodeInfoDict: dict, + ): + return TFNodesAST.VariableV2( + graph, curNode, dictNodeNameToOutVarStr, extraNodeInfoDict + ) diff --git a/Athos/tests/conftest.py b/Athos/tests/conftest.py index 52f60356..b4a425c1 100644 --- a/Athos/tests/conftest.py +++ b/Athos/tests/conftest.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import pytest import tempfile import shutil diff --git a/Athos/tests/tf/unittests/test_arith_binops.py b/Athos/tests/tf/unittests/test_arith_binops.py index d6c415f0..b7ee6c33 100644 --- a/Athos/tests/tf/unittests/test_arith_binops.py +++ b/Athos/tests/tf/unittests/test_arith_binops.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np @@ -34,6 +34,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) from tests.utils import Config, Compiler, assert_almost_equal + @pytest.mark.parametrize( "a_shape,b_shape,dtype", [ @@ -63,12 +64,19 @@ def test_arith_binop(test_dir, backend, tfOp, a_shape, b_shape, dtype): assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) return + @pytest.mark.parametrize( "a_shape, b_shape, data_format, dtype", [ ([4, 1, 4], [4], None, np.single), # Normal - ([4, 1, 4], [4], 'N..C', np.single), # Same as above - pytest.param([4, 4, 1], [4], 'NC..', np.single, marks=pytest.mark.skip(reason="[bias_add] NC.. not supported")), # Normal + ([4, 1, 4], [4], "N..C", np.single), # Same as above + pytest.param( + [4, 4, 1], + [4], + "NC..", + np.single, + marks=pytest.mark.skip(reason="[bias_add] NC.. not supported"), + ), # Normal ], ) def test_bias_add(test_dir, backend, a_shape, b_shape, data_format, dtype): @@ -93,14 +101,43 @@ def test_bias_add(test_dir, backend, a_shape, b_shape, data_format, dtype): @pytest.mark.parametrize( "tfOp, a_val, divisor", [ - pytest.param(tf.divide, [7, -7], 5, marks=pytest.mark.skip(reason="[divide] Support for parsing DOUBLES")), # [1, -2] + pytest.param( + tf.divide, + [7, -7], + 5, + marks=pytest.mark.skip(reason="[divide] Support for parsing DOUBLES"), + ), # [1, -2] (tf.divide, [7.0, -7.0], 5.0), # [1.4, -1.4] - pytest.param(tf.truediv, [7, -7], 5, marks=pytest.mark.skip(reason="[divide] Support for parsing DOUBLES")), # [1.4, -1.4] + pytest.param( + tf.truediv, + [7, -7], + 5, + marks=pytest.mark.skip(reason="[divide] Support for parsing DOUBLES"), + ), # [1.4, -1.4] (tf.truediv, [7.0], 5.0), # [1.4] (tf.divide, 7.0, 5.0), # 1.4 - pytest.param(tf.floordiv, [7, -7], 5, marks=pytest.mark.skip(reason="[divide] Add support for converting div by constant into a mul")), # [1, -2] - pytest.param(tf.floordiv, [7.0, -7.0], 5.0, marks=pytest.mark.skip(reason="[divide] Add support for converting div by constant into a mul")), # [1.0, -2.0] - pytest.param(tf.truncatediv, -7, 5, marks=pytest.mark.skip(reason="[divide] Truncated div not supported")), # -1 + pytest.param( + tf.floordiv, + [7, -7], + 5, + marks=pytest.mark.skip( + reason="[divide] Add support for converting div by constant into a mul" + ), + ), # [1, -2] + pytest.param( + tf.floordiv, + [7.0, -7.0], + 5.0, + marks=pytest.mark.skip( + reason="[divide] Add support for converting div by constant into a mul" + ), + ), # [1.0, -2.0] + pytest.param( + tf.truncatediv, + -7, + 5, + marks=pytest.mark.skip(reason="[divide] Truncated div not supported"), + ), # -1 ], ) def test_div(test_dir, backend, tfOp, a_val, divisor, dtype): diff --git a/Athos/tests/tf/unittests/test_batchnorm.py b/Athos/tests/tf/unittests/test_batchnorm.py index 9b0e611e..b1ab6bfb 100644 --- a/Athos/tests/tf/unittests/test_batchnorm.py +++ b/Athos/tests/tf/unittests/test_batchnorm.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np @@ -36,14 +36,13 @@ @pytest.mark.parametrize( "a_shape, scale, offset, mean, variance", - [([1, 2, 2, 1], [1.5], [2.3], [0.5], [0.2]), - #([1], 1.5, 2.3, 0.5, 0.2), ([], 1.5, 2.3, 0.5, 0.2) + [ + ([1, 2, 2, 1], [1.5], [2.3], [0.5], [0.2]), + # ([1], 1.5, 2.3, 0.5, 0.2), ([], 1.5, 2.3, 0.5, 0.2) ], ) @pytest.mark.parametrize("dtype", [np.single]) -@pytest.mark.parametrize( - "tfOp", [tf.raw_ops.FusedBatchNorm] -) +@pytest.mark.parametrize("tfOp", [tf.raw_ops.FusedBatchNorm]) @pytest.mark.skip(reason="[batch_norm] Test not complete") def test_fused_batch_norm( test_dir, backend, tfOp, a_shape, scale, offset, mean, variance, dtype @@ -68,4 +67,4 @@ def test_fused_batch_norm( compiler = Compiler(graph, config, test_dir) mpc_output = compiler.compile_and_run([a_inp]) assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) - return \ No newline at end of file + return diff --git a/Athos/tests/tf/unittests/test_non_linear.py b/Athos/tests/tf/unittests/test_non_linear.py index 1c35d050..83da5435 100644 --- a/Athos/tests/tf/unittests/test_non_linear.py +++ b/Athos/tests/tf/unittests/test_non_linear.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np @@ -62,6 +62,7 @@ def test_non_linear(test_dir, backend, tfOp, a_shape, dtype): assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) return + @pytest.mark.skip(reason="[softmax] Haven't made non-linear functionalities public") @pytest.mark.parametrize("a_shape, axis", [([2, 3], 1), ([1], 0)]) @pytest.mark.parametrize("dtype", [np.single]) diff --git a/Athos/tests/tf/unittests/test_shape_manipulation.py b/Athos/tests/tf/unittests/test_shape_manipulation.py index 6bbf9547..44d23c87 100644 --- a/Athos/tests/tf/unittests/test_shape_manipulation.py +++ b/Athos/tests/tf/unittests/test_shape_manipulation.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np @@ -40,8 +40,8 @@ ([2, 3], [6]), ([6], [2, 3]), ([2, 3], [3, 2]), - ([2, 3], [-1]), # Flatten 1-D, - ([1], []), # convert to scalar, + ([2, 3], [-1]), # Flatten 1-D, + ([1], []), # convert to scalar, ([3, 2, 3], [2, -1]), # infer -1 as 9, ([3, 2, 3], [-1, 9]), # infer -1 as 2 ], diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py index 87e9b77c..6477de6b 100644 --- a/Athos/tests/tf/unittests/test_unaryops.py +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -1,4 +1,4 @@ -''' +""" Authors: Pratik Bhatu. @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import tensorflow as tf import numpy as np @@ -83,7 +83,7 @@ def test_uop(test_dir, backend, tfOp, a_shape, dtype): ) @pytest.mark.parametrize("dtype", [np.single]) @pytest.mark.parametrize("tfOp", [tf.math.reduce_mean, tf.reduce_sum]) -#@pytest.mark.skip(reason="[reduce] Reduce mean output mismatch and shape failure") +# @pytest.mark.skip(reason="[reduce] Reduce mean output mismatch and shape failure") def test_reduce(test_dir, backend, tfOp, a_shape, axis, keepdims, dtype): graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) @@ -223,4 +223,4 @@ def test_fill(test_dir, backend, a_shape, value): compiler = Compiler(graph, config, test_dir) mpc_output = compiler.compile_and_run([], timeoutSeconds=60) assert_almost_equal(tf_output=expected_output, mpc_tensor=mpc_output, precision=2) - return \ No newline at end of file + return From 0794cfa6d4923b51e276ce059281a678a6a46c81 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 22 Jan 2021 12:02:56 +0530 Subject: [PATCH 48/72] Update .git-blame-ignore-revs --- Athos/.git-blame-ignore-revs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Athos/.git-blame-ignore-revs b/Athos/.git-blame-ignore-revs index 45713d37..9b5194cd 100644 --- a/Athos/.git-blame-ignore-revs +++ b/Athos/.git-blame-ignore-revs @@ -1 +1 @@ -da9f654919ac47e08e166b389653b4abfa69900a +d61c6a1d1b105597d6934890adb6ae89f64d868c From 16b135b06c8962aa69710da48441b8f798de4fa8 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 22 Jan 2021 16:43:13 +0530 Subject: [PATCH 49/72] Rename PORTHOS2PC to SCI --- Athos/CompileSampleNetworks.py | 24 ++-- Athos/CompileTF.sh | 6 +- Athos/CompileTFGraph.py | 20 +-- .../sample_networks/print_stats_2pc.sh | 2 +- Athos/README.md | 6 +- Athos/TFEzPCLibrary/Library32_porthos2pc.ezpc | 126 ------------------ Athos/TFEzPCLibrary/Library64_porthos2pc.ezpc | 126 ------------------ Athos/tests/utils.py | 6 +- .../{codegenporthos2pc.ml => codegensci.ml} | 16 +-- EzPC/EzPC/config.ml | 12 +- EzPC/EzPC/main.ml | 24 ++-- 11 files changed, 58 insertions(+), 310 deletions(-) delete mode 100644 Athos/TFEzPCLibrary/Library32_porthos2pc.ezpc delete mode 100644 Athos/TFEzPCLibrary/Library64_porthos2pc.ezpc rename EzPC/EzPC/{codegenporthos2pc.ml => codegensci.ml} (95%) diff --git a/Athos/CompileSampleNetworks.py b/Athos/CompileSampleNetworks.py index fa45a453..733017cd 100644 --- a/Athos/CompileSampleNetworks.py +++ b/Athos/CompileSampleNetworks.py @@ -45,7 +45,7 @@ def parse_args(): //--------------------------- Mandatory options --------------------------- "network_name":"ResNet", // Any network name from Athos/Networks directory (ResNet/DenseNet/ChestXRay/..) - "target":"PORTHOS2PC", // Compilation target. ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC + "target":"SCI", // Compilation target. ABY/CPP/CPPRING/PORTHOS/SCI @@ -54,12 +54,12 @@ def parse_args(): "bitlength":64, // Bit length to compile for. DEFAULT=64. "modulo" : 32, // Modulo to be used for shares. Applicable for - // CPPRING/PORTHOS2PC backend. For - // PORTHOS2PC + backend=OT => Power of 2 - // PORTHOS2PC + backend=HE => Prime value." + // CPPRING/SCI backend. For + // SCI + backend=OT => Power of 2 + // SCI + backend=HE => Prime value." "backend" : "OT", // Backend to be used - OT/HE (DEFAULT=OT). - // Only applicable for PORTHOS2PC backend + // Only applicable for SCI backend "disable_all_hlil_opts" : false, // Disable all optimizations in HLIL. DEFAULT=false "disable_relu_maxpool_opts" : false, // Disable Relu-Maxpool optimization. DEFAULT=false @@ -109,11 +109,11 @@ def generate_code(params, debug=False): assert bitlength <= 64 and bitlength >= 1, "Bitlen must be >= 1 and <= 64" assert target in [ "PORTHOS", - "PORTHOS2PC", + "SCI", "ABY", "CPP", "CPPRING", - ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" + ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/SCI" cwd = os.getcwd() athos_dir = os.path.dirname(os.path.abspath(__file__)) @@ -198,7 +198,7 @@ def generate_code(params, debug=False): output_name = ezpc_file_name[:-5] + "0.cpp" if modulo is not None: ezpc_args += "--modulo {} ".format(modulo) - if target == "PORTHOS2PC": + if target == "SCI": ezpc_args += "--backend {} ".format(backend.upper()) output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) if target in ["PORTHOS"]: @@ -222,7 +222,7 @@ def generate_code(params, debug=False): print( "--------------------------------------------------------------------------------" ) - if target == "PORTHOS2PC": + if target == "SCI": program_name = model_base_name + "_" + target + "_" + backend + ".out" else: program_name = model_base_name + "_" + target + ".out" @@ -261,7 +261,7 @@ def generate_code(params, debug=False): print( "Not compiling generated code. Please follow the readme and build Porthos." ) - elif target == "PORTHOS2PC": + elif target == "SCI": sci = os.path.join(athos_dir, "..", "SCI") sci_src = os.path.join(sci, "src") sci_lib = os.path.join(sci, "build", "lib") @@ -305,7 +305,7 @@ def generate_code(params, debug=False): print( "--------------------------------------------------------------------------------" ) - mode = target + " - " + backend if target == "PORTHOS2PC" else target + mode = target + " - " + backend if target == "SCI" else target print("Running program securely in {} mode".format(mode)) print( "--------------------------------------------------------------------------------" @@ -318,7 +318,7 @@ def generate_code(params, debug=False): run_script_path = os.path.join(sample_networks_dir, "run_demo_cpp.sh") elif target == "PORTHOS": run_script_path = os.path.join(sample_networks_dir, "run_demo_3pc.sh") - elif target == "PORTHOS2PC": + elif target == "SCI": run_script_path = os.path.join(sample_networks_dir, "run_demo_2pc.sh") os.system( "{script} {model_dir} {model_binary} {model_input} {model_weight}".format( diff --git a/Athos/CompileTF.sh b/Athos/CompileTF.sh index eadb27dd..4af393ba 100755 --- a/Athos/CompileTF.sh +++ b/Athos/CompileTF.sh @@ -38,10 +38,10 @@ usage() { echo -e "CrypTFlow compilation script. Options:"; echo -e "<-b|--bitlen> :: Bit length to compile for. Defaults to 64"; echo -e "<-s|--scaling-fac> :: Scaling factor to compile for. Defaults to 12."; - echo -e "<-t|--target> :: Compilation target. Possible options: ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC. Defaults to CPP."; + echo -e "<-t|--target> :: Compilation target. Possible options: ABY/CPP/CPPRING/PORTHOS/SCI. Defaults to CPP."; echo -e "<-f|--filename> :: Python tensorflow file to compile." - echo -e "<--modulo> :: Modulo to be used for shares. Applicable for CPPRING/PORTHOS2PC backend. For PORTHOS2PC, for backend type OT, this should be power of 2 and for backend type HE, this should be a prime." - echo -e "<--backend> :: Backend to be used - OT/HE (default OT). Applicable for PORTHOS2PC backend." + echo -e "<--modulo> :: Modulo to be used for shares. Applicable for CPPRING/SCI backend. For SCI, for backend type OT, this should be power of 2 and for backend type HE, this should be a prime." + echo -e "<--backend> :: Backend to be used - OT/HE (default OT). Applicable for SCI backend." echo -e "<--disable-hlil-all-opti> :: Disable all optimizations in HLIL." echo -e "<--disable-rmo> :: Disable Relu-Maxpool optimization." echo -e "<--disable-liveness-opti> :: Disable Liveness Optimization." diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index b5d6605f..8d010b43 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -50,7 +50,7 @@ def parse_args(): "output1", "output2" ], - "target":"PORTHOS2PC", // Compilation target. ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC + "target":"SCI", // Compilation target. ABY/CPP/CPPRING/PORTHOS/SCI @@ -64,12 +64,12 @@ def parse_args(): "input2":"2,245,234,3" // placeholder nodes have shape info in the .pb file. }, "modulo" : 32, // Modulo to be used for shares. Applicable for - // CPPRING/PORTHOS2PC backend. For - // PORTHOS2PC + backend=OT => Power of 2 - // PORTHOS2PC + backend=HE => Prime value." + // CPPRING/SCI backend. For + // SCI + backend=OT => Power of 2 + // SCI + backend=HE => Prime value." "backend" : "OT", // Backend to be used - OT/HE (default OT). - // Only applicable for PORTHOS2PC backend + // Only applicable for SCI backend "disable_all_hlil_opts" : false, // Disable all optimizations in HLIL. DEFAULT=false "disable_relu_maxpool_opts" : false, // Disable Relu-Maxpool optimization. DEFAULT=false @@ -113,11 +113,11 @@ def generate_code(params, debug=False): assert bitlength <= 64 and bitlength >= 1, "Bitlen must be >= 1 and <= 64" assert target in [ "PORTHOS", - "PORTHOS2PC", + "SCI", "ABY", "CPP", "CPPRING", - ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/PORTHOS2PC" + ], "Target must be any of ABY/CPP/CPPRING/PORTHOS/SCI" cwd = os.getcwd() athos_dir = os.path.dirname(os.path.abspath(__file__)) @@ -192,7 +192,7 @@ def generate_code(params, debug=False): output_name = ezpc_file_name[:-5] + "0.cpp" if modulo is not None: ezpc_args += "--modulo {} ".format(modulo) - if target == "PORTHOS2PC": + if target == "SCI": ezpc_args += "--backend {} ".format(backend.upper()) output_name = ezpc_file_name[:-5] + "_{}0.cpp".format(backend.upper()) if target in ["PORTHOS"]: @@ -210,7 +210,7 @@ def generate_code(params, debug=False): output_file = os.path.join(model_abs_dir, output_name) print("Compiling generated code to {target} target".format(target=target)) - if target == "PORTHOS2PC": + if target == "SCI": program_name = model_base_name + "_" + target + "_" + backend + ".out" else: program_name = model_base_name + "_" + target + ".out" @@ -246,7 +246,7 @@ def generate_code(params, debug=False): print( "Not compiling generated code. Please follow the readme and build Porthos." ) - elif target == "PORTHOS2PC": + elif target == "SCI": sci = os.path.join(athos_dir, "..", "SCI") sci_src = os.path.join(sci, "src") sci_lib = os.path.join(sci, "build", "lib") diff --git a/Athos/CompilerScripts/sample_networks/print_stats_2pc.sh b/Athos/CompilerScripts/sample_networks/print_stats_2pc.sh index 5236c617..7834b374 100755 --- a/Athos/CompilerScripts/sample_networks/print_stats_2pc.sh +++ b/Athos/CompilerScripts/sample_networks/print_stats_2pc.sh @@ -38,7 +38,7 @@ echo "-------------------------------------------------------" if [ $PARTY -eq 1 ]; then echo "Model outputs:" - echo -e "MPC PORTHOS (2PC) output:\t $(awk '$0==($0+0)' ${MODEL_DIR}/party${PARTY}_mpc_output.out)" + echo -e "MPC SCI (2PC) output:\t $(awk '$0==($0+0)' ${MODEL_DIR}/party${PARTY}_mpc_output.out)" echo -e "Tensorflow output:\t\t $(cat ${MODEL_DIR}/tf_pred.float)" echo "" fi diff --git a/Athos/README.md b/Athos/README.md index 387b0c96..ee602b90 100644 --- a/Athos/README.md +++ b/Athos/README.md @@ -42,7 +42,7 @@ The script takes a config file as input. The contents of the config are: - *network_name*: Can be any network in the Networks directory. - *target*: This is the secure protocol the model will run in. The possible values are: - **PORTHOS**: The semi-honest 3PC protocol. - - **PORTHOS2PC**: The semi-honest 2PC protocol in SCI. + - **SCI**: The semi-honest 2PC protocol in SCI. - **CPP**: A non-secure debug backend which outputs plain C++ to test for correctness. - *run_in_tmux*: If true, the script spawns a tmux session to run the network. There is a terminal pane for each party. You can modify the config file according to which network and backend you want to compile for. See ```python CompileSampleNetworks.py --help``` for more information about the parameters of the config file. @@ -80,10 +80,10 @@ To run the network in **2PC mode (SCI)**, open 2 terminals and do the following - Party 0 [Server]: - ``` ./Networks/ResNet/ResNet_PORTHOS2PC_OT.out r=1 p=12345 < model_weights_scale_12.inp" ``` + ``` ./Networks/ResNet/ResNet_SCI_OT.out r=1 p=12345 < model_weights_scale_12.inp" ``` - Party 1 [Client]: - ``` ./Networks/ResNet/ResNet_PORTHOS2PC_OT.out r=2 ip=127.0.0.1 p=12345 < model_input_scale_12.inp" ``` + ``` ./Networks/ResNet/ResNet_SCI_OT.out r=2 ip=127.0.0.1 p=12345 < model_input_scale_12.inp" ``` To run the network in **CPP mode (1PC-debug-non-secure)**, open a terminal and do the following: - ``` ./Networks/ResNet/ResNet_CPP.out < <(cat model_input_scale_12.inp model_weights_scale_12.inp) ``` diff --git a/Athos/TFEzPCLibrary/Library32_porthos2pc.ezpc b/Athos/TFEzPCLibrary/Library32_porthos2pc.ezpc deleted file mode 100644 index 8d6197f7..00000000 --- a/Athos/TFEzPCLibrary/Library32_porthos2pc.ezpc +++ /dev/null @@ -1,126 +0,0 @@ -(* - -Authors: Nishant Kumar. - -Copyright: -Copyright (c) 2020 Microsoft Research -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -*) - -(**************************) -extern void MatMul2D(int32_pl i, int32_pl j, int32_pl k, int32_al[i][j] A, int32_al[j][k] B, int32_al[i][k] C, bool_pl modelIsA); - -(**************************) -extern void ArgMax(int32_pl s1, int32_pl s2, int32_al[s1][s2] inArr, int32_al[s1] outArr); - -(**************************) -extern void Relu(int32_pl s1, int32_al[s1] inArr, int32_al[s1] outArr, int32_pl sf, bool_pl doTruncation); - -(**************************) -extern void Floor(int32_pl s1, int32_al[s1] inArr, int32_al[s1] outArr, int32_pl sf); - -(**************************) -(* int32_al[N][H][W][C] input *) -extern void MaxPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, - int32_pl ksizeH, int32_pl ksizeW, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, - int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, - int32_al[N1][imgH][imgW][C1] inArr, - int32_al[N][H][W][C] outArr); - -(**************************) -(* int32_al[N][H][W][C] input *) -extern void AvgPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, - int32_pl ksizeH, int32_pl ksizeW, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, - int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, - int32_al[N1][imgH][imgW][C1] inArr, - int32_al[N][H][W][C] outArr); - -(**************************) -extern void ElemWiseSecretSharedVectorMult(int32_pl s1, int32_al[s1] arr1, int32_al[s1] arr2, int32_al[s1] outArr); -extern void ElemWiseActModelVectorMult(int32_pl s1, int32_al[s1] arr1, int32_al[s1] arr2, int32_al[s1] outArr); -extern void ElemWiseVectorPublicDiv(int32_pl s1, int32_al[s1] arr1, int32_pl divisor, int32_al[s1] outArr); - -(**************************) -extern void ScaleUp(int32_pl s1, int32_al[s1] arr, int32_pl sf); - -(**************************) -extern void ScaleDown(int32_pl s1, int32_al[s1] arr, int32_pl sf); - -(**************************) -extern void ClearMemSecret1(int32_pl s1, int32_al[s1] arr); -extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr); -extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int32_al[s1][s2][s3] arr); -extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] arr); -extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] arr); -extern void ClearMemPublic(int32_pl x); -extern void ClearMemPublic1(int32_pl s, int32_pl[s] x); -extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr); -extern void ClearMemPublic3(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl[s1][s2][s3] arr); -extern void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl[s1][s2][s3][s4] arr); -extern void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr); - -(**************************) -extern void StartComputation(); -extern void EndComputation(); - -(**************************) -extern void Conv2DWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, - int32_al[N][H][W][CI] inputArr, - int32_al[FH][FW][CI][CO] filterArr, - int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); - -extern void Conv3DWrapper(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int32_al[N][D][H][W][CI] inputArr, - int32_al[FD][FH][FW][CI][CO] filterArr, - int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); - -extern void Conv2DGroupWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, int32_pl G, - int32_al[N][H][W][CI] inputArr, - int32_al[FH][FW][CI/G][CO] filterArr, - int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); - -extern void ConvTranspose2DWrapper(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl H, int32_pl W, - int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideH, int32_pl strideW, - int32_al[N][HPrime][WPrime][CI] inputArr, - int32_al[FH][FW][CO][CI] filterArr, - int32_al[N][H][W][CO] outArr); - -extern void ConvTranspose3DWrapper(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl D, int32_pl H, int32_pl W, - int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int32_al[FD][FH][FW][CO][CI] filterArr, - int32_al[N][D][H][W][CO] outArr); diff --git a/Athos/TFEzPCLibrary/Library64_porthos2pc.ezpc b/Athos/TFEzPCLibrary/Library64_porthos2pc.ezpc deleted file mode 100644 index 5a1a6b35..00000000 --- a/Athos/TFEzPCLibrary/Library64_porthos2pc.ezpc +++ /dev/null @@ -1,126 +0,0 @@ -(* - -Authors: Nishant Kumar. - -Copyright: -Copyright (c) 2020 Microsoft Research -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -*) - -(**************************) -extern void MatMul2D(int32_pl i, int32_pl j, int32_pl k, int64_al[i][j] A, int64_al[j][k] B, int64_al[i][k] C, bool_pl modelIsA); - -(**************************) -extern void ArgMax(int32_pl s1, int32_pl s2, int64_al[s1][s2] inArr, int64_al[s1] outArr); - -(**************************) -extern void Relu(int32_pl s1, int64_al[s1] inArr, int64_al[s1] outArr, int32_pl sf, bool_pl doTruncation); - -(**************************) -extern void Floor(int32_pl s1, int64_al[s1] inArr, int64_al[s1] outArr, int32_pl sf); - -(**************************) -(* int64_al[N][H][W][C] input *) -extern void MaxPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, - int32_pl ksizeH, int32_pl ksizeW, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, - int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, - int64_al[N1][imgH][imgW][C1] inArr, - int64_al[N][H][W][C] outArr); - -(**************************) -(* int64_al[N][H][W][C] input *) -extern void AvgPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, - int32_pl ksizeH, int32_pl ksizeW, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, - int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, - int64_al[N1][imgH][imgW][C1] inArr, - int64_al[N][H][W][C] outArr); - -(**************************) -extern void ElemWiseSecretSharedVectorMult(int32_pl s1, int64_al[s1] arr1, int64_al[s1] arr2, int64_al[s1] outArr); -extern void ElemWiseActModelVectorMult(int32_pl s1, int64_al[s1] arr1, int64_al[s1] arr2, int64_al[s1] outArr); -extern void ElemWiseVectorPublicDiv(int32_pl s1, int64_al[s1] arr1, int32_pl divisor, int64_al[s1] outArr); - -(**************************) -extern void ScaleUp(int32_pl s1, int64_al[s1] arr, int32_pl sf); - -(**************************) -extern void ScaleDown(int32_pl s1, int64_al[s1] arr, int32_pl sf); - -(**************************) -extern void ClearMemSecret1(int32_pl s1, int64_al[s1] arr); -extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr); -extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int64_al[s1][s2][s3] arr); -extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] arr); -extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr); -extern void ClearMemPublic(int32_pl x); -extern void ClearMemPublic1(int32_pl s, int32_pl[s] x); -extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr); -extern void ClearMemPublic3(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl[s1][s2][s3] arr); -extern void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl[s1][s2][s3][s4] arr); -extern void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr); - -(**************************) -extern void StartComputation(); -extern void EndComputation(); - -(**************************) -extern void Conv2DWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, - int64_al[N][H][W][CI] inputArr, - int64_al[FH][FW][CI][CO] filterArr, - int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); - -extern void Conv3DWrapper(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int64_al[N][D][H][W][CI] inputArr, - int64_al[FD][FH][FW][CI][CO] filterArr, - int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); - -extern void Conv2DGroupWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, - int32_pl strideH, int32_pl strideW, int32_pl G, - int64_al[N][H][W][CI] inputArr, - int64_al[FH][FW][CI/G][CO] filterArr, - int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); - -extern void ConvTranspose2DWrapper(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl H, int32_pl W, - int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideH, int32_pl strideW, - int64_al[N][HPrime][WPrime][CI] inputArr, - int64_al[FH][FW][CO][CI] filterArr, - int64_al[N][H][W][CO] outArr); - -extern void ConvTranspose3DWrapper(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, - int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, - int32_pl D, int32_pl H, int32_pl W, - int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, - int32_pl strideD, int32_pl strideH, int32_pl strideW, - int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, - int64_al[FD][FH][FW][CO][CI] filterArr, - int64_al[N][D][H][W][CO] outArr); diff --git a/Athos/tests/utils.py b/Athos/tests/utils.py index 9086685f..e0a3c280 100644 --- a/Athos/tests/utils.py +++ b/Athos/tests/utils.py @@ -49,13 +49,13 @@ def __init__(self, mode): elif mode == "3PC": self.config["target"] = "PORTHOS" elif mode == "2PC_OT": - self.config["target"] = "PORTHOS2PC" + self.config["target"] = "SCI" self.config["bitlength"] = 41 self.config["scale"] = 12 self.config["backend"] = "OT" elif mode == "2PC_HE": - self.config["target"] = "PORTHOS2PC" + self.config["target"] = "SCI" self.config["bitlength"] = 41 self.config["scale"] = 12 self.config["backend"] = "HE" @@ -181,7 +181,7 @@ def run(self, inputs, timeoutSeconds): p.wait(timeoutSeconds) except subprocess.TimeoutExpired: p.kill() - elif self.target == "PORTHOS2PC": + elif self.target == "SCI": util_dir = os.path.dirname(os.path.abspath(__file__)) sci_dir = os.path.join(util_dir, "..", "..", "SCI") port = 1234 diff --git a/EzPC/EzPC/codegenporthos2pc.ml b/EzPC/EzPC/codegensci.ml similarity index 95% rename from EzPC/EzPC/codegenporthos2pc.ml rename to EzPC/EzPC/codegensci.ml index 71fc6535..b598e61f 100644 --- a/EzPC/EzPC/codegenporthos2pc.ml +++ b/EzPC/EzPC/codegensci.ml @@ -124,8 +124,8 @@ let o_subsumption (src:label) (tgt:secret_label) (t:typ) (arg:comp) :comp = | Secret Boolean -> failwith "Codegen: Subsumption from secrets is not allowed for this backend." let o_basetyp (t:base_type) :comp = - let uint32_basetype_str :string = if Config.get_porthos2pc_backend () = OT then "uint32_t" else "uint64_t" in - let int32_basetype_str :string = if Config.get_porthos2pc_backend () = OT then "int32_t" else "int64_t" in + let uint32_basetype_str :string = if Config.get_sci_backend () = OT then "uint32_t" else "uint64_t" in + let int32_basetype_str :string = if Config.get_sci_backend () = OT then "int32_t" else "int64_t" in match t with | UInt32 -> o_str uint32_basetype_str | UInt64 -> o_str "uint64_t" @@ -165,7 +165,7 @@ let rec o_secret_binop (g:gamma) (op:binop) (sl:secret_label) (e1:expr) (e2:expr and o_expr (g:gamma) (e:expr) :comp = let o_expr = o_expr g in let o_codegen_expr = o_codegen_expr g in - let uint32_basetype_str :string = if Config.get_porthos2pc_backend () = OT then "uint32_t" else "uint64_t" in + let uint32_basetype_str :string = if Config.get_sci_backend () = OT then "uint32_t" else "uint64_t" in let rec o_array_read_rec (ga:gamma) (ea:expr) : (comp*comp*int) = match ea.data with | Array_read (e1,e2) -> @@ -246,7 +246,7 @@ and o_codegen_expr (g:gamma) (e:codegen_expr) :comp = | Clear_val _ -> failwith ("Codegen_expr Clear_val is unsupported by this backend.") let rec o_typ_rec (t:typ) :comp = - let uint32_basetype_str :string = if Config.get_porthos2pc_backend () = OT then "uint32_t" else "uint64_t" in + let uint32_basetype_str :string = if Config.get_sci_backend () = OT then "uint32_t" else "uint64_t" in match t.data with | Base (Int64, Some (Secret _)) -> o_str "uint64_t" | Base (Int32, Some (Secret _)) -> o_str uint32_basetype_str @@ -270,7 +270,7 @@ let o_array_init (g:gamma) (t:typ) :comp = o_app s (List.map (o_expr g) l) let o_for (index:comp) (lower:comp) (upper:comp) (body:comp) :comp = - let uint32_basetype_str :string = if Config.get_porthos2pc_backend () = OT then "uint32_t" else "uint64_t" in + let uint32_basetype_str :string = if Config.get_sci_backend () = OT then "uint32_t" else "uint64_t" in let init = seq (o_str ("for (" ^ uint32_basetype_str ^ " ")) (seq index (seq (o_str " = ") lower)) in let term = seq index (seq (o_str " < ") upper) in let incr = seq index (o_str "++)") in @@ -685,14 +685,14 @@ let o_one_program ((globals, main):global list * codegen_stmt) (ofname:string) : let (hash_define_str, main_prelude) = let modulo_str = Config.get_modulo () |> Uint64.to_string in let hash_define_str = - if (Config.get_porthos2pc_backend () = OT) then "SCI_OT" + if (Config.get_sci_backend () = OT) then "SCI_OT" else "SCI_HE" in let main_prelude = - if (Config.get_porthos2pc_backend () = OT) then begin + if (Config.get_sci_backend () = OT) then begin (* OT case *) if Config.get_modulo () = Uint64.shift_left (Uint64.of_int 1) (Config.get_actual_bitlen ()) then "" - else failwith "Modulo can only be (1< snd |> List.length) ^ " partition(s), generating .cpp files"); if Config.get_codegen () = OBLIVC then Codegenoblivc.o_program p file else if Config.get_codegen () = PORTHOS then Codegenporthos.o_program p file - else if Config.get_codegen () = PORTHOS2PC then Codegenporthos2pc.o_program p file + else if Config.get_codegen () = SCI then Codegensci.o_program p file else if Config.get_codegen () = CPPRING then Codegencppring.o_program p file else Codegen.o_program p file; Well_typed ()) in @@ -119,10 +119,10 @@ let specs = Arg.align [ | "CPP" -> CPP |> Config.set_codegen | "OBLIVC" -> OBLIVC |> Config.set_codegen | "PORTHOS" -> PORTHOS |> Config.set_codegen - | "PORTHOS2PC" -> PORTHOS2PC |> Config.set_codegen + | "SCI" -> SCI |> Config.set_codegen | "CPPRING" -> CPPRING |> Config.set_codegen | _ -> failwith "Invalid codegen mode"), - " Codegen mode (ABY or CPP or OBLIVC or PORTHOS or PORTHOS2PC or CPPRING, default ABY)"); + " Codegen mode (ABY or CPP or OBLIVC or PORTHOS or SCI or CPPRING, default ABY)"); ("--o_prefix", Arg.String (fun s -> o_prefix := s), " Prefix for output files, default is the input file prefix"); ("--disable-tac", Arg.Unit Config.disable_tac, " Disable 3-address code transformation (also disables the CSE optimization)"); ("--disable-cse", Arg.Unit Config.disable_cse, " Disable Common Subexpression Elimination optimization"); @@ -135,13 +135,13 @@ let specs = Arg.align [ ("--shares_dir", Arg.String (fun s -> Config.set_shares_dir s), " Directory where share files should be created"); ("--debug_partitions", Arg.Unit Config.set_debug_partitions, " Debug partitions (if codegen is ABY then dump shares in clear, if codegen is CPP then generate partitions)"); ("--modulo", Arg.String Config.set_modulo, - "Modulo to be used for shares. Applicable for CPPRING/PORTHOS2PC backend. - For PORTHOS2PC, for backend type OT, this should be power of 2 and for backend type HE, this should be a prime."); + "Modulo to be used for shares. Applicable for CPPRING/SCI backend. + For SCI, for backend type OT, this should be power of 2 and for backend type HE, this should be a prime."); ("--backend", Arg.String (fun s -> match s with - | "OT" -> OT |> Config.set_porthos2pc_backend - | "HE" -> HE |> Config.set_porthos2pc_backend + | "OT" -> OT |> Config.set_sci_backend + | "HE" -> HE |> Config.set_sci_backend | _ -> failwith "Invalid backend type"), - "Porthos2PC Backend Type (OT or HE, default OT)."); + "SCI Backend Type (OT or HE, default OT)."); ("--sf", Arg.Int Config.set_sf, "Scale factor to be used in compilation. Valid only for PORTHOS."); ] let _ = @@ -156,9 +156,9 @@ let _ = if Config.get_codegen () <> PORTHOS && Config.get_sf () <> 0 then failwith "sf only valid for PORTHOS."; - if Config.get_codegen () <> CPPRING && Config.get_codegen () <> PORTHOS2PC + if Config.get_codegen () <> CPPRING && Config.get_codegen () <> SCI && (Config.get_modulo () <> (Uint64.shift_left (Uint64.of_int 1) (Config.get_bitlen ())) || (Config.get_bitlen () <> 32 && Config.get_bitlen () <> 64)) - then failwith "Modulo and {bitlen not equal to 32/64} only supported for CPPRING/PORTHOS2PC backend."; + then failwith "Modulo and {bitlen not equal to 32/64} only supported for CPPRING/SCI backend."; if Config.get_codegen () = CPPRING && (Config.get_bitlen () = 64 || Config.get_bitlen () = 32) && Config.get_modulo () = (Uint64.shift_left (Uint64.of_int 1) (Config.get_bitlen ())) @@ -176,8 +176,8 @@ let _ = if !o_prefix = "" then o_prefix := prefix; let _ = - if Config.get_codegen () = PORTHOS2PC then begin - let backend_type = if Config.get_porthos2pc_backend () = OT then "OT" else "HE" in + if Config.get_codegen () = SCI then begin + let backend_type = if Config.get_sci_backend () = OT then "OT" else "HE" in o_prefix := !o_prefix ^ "_" ^ backend_type; end in From a28334f93a5aba75e72a46db58ccc23564308623 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 3 Feb 2021 15:15:22 +0530 Subject: [PATCH 50/72] Handle numpy array of 0 volume Sometimes shape can be (x,y,0,z). And the final numpy array will be an empty array only. --- Athos/TFCompiler/DumpTFMtData.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Athos/TFCompiler/DumpTFMtData.py b/Athos/TFCompiler/DumpTFMtData.py index c6afb432..61a42e31 100644 --- a/Athos/TFCompiler/DumpTFMtData.py +++ b/Athos/TFCompiler/DumpTFMtData.py @@ -240,7 +240,7 @@ def save_weights(optimized_graph_def, sess, feed_dict, filename, scaling_factor) values = sess.run(graph_vars, feed_dict) with open(filename, "w") as ff: for val in values: - if val.shape == (0,): # Empty array, nothing to dump. + if val.shape.count(0) > 0: # Empty array, nothing to dump. continue for xx in numpy.nditer(val, order="C"): ff.write(str(int(xx * (1 << scaling_factor))) + " ") From 2e10377f2952d79d846e6104d458faad5dcb16a9 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 3 Feb 2021 15:38:04 +0530 Subject: [PATCH 51/72] Add script to parse scaled output produced. The script scales down the output, fixes the sign, and dumps it as a numpy array --- Athos/CompileTFGraph.py | 2 ++ Athos/CompilerScripts/get_output.py | 36 +++++++++++++++++++++++++++++ Athos/tests/utils.py | 20 ++-------------- 3 files changed, 40 insertions(+), 18 deletions(-) create mode 100644 Athos/CompilerScripts/get_output.py diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 8d010b43..27a10bf7 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -273,6 +273,8 @@ def generate_code(params, debug=False): ) os.chdir(cwd) + print("Generated binary: {}".format(program_path)) + print("Use as input to server (model weights): {}".format(weights_path)) return (program_path, weights_path) diff --git a/Athos/CompilerScripts/get_output.py b/Athos/CompilerScripts/get_output.py new file mode 100644 index 00000000..d7ca78cb --- /dev/null +++ b/Athos/CompilerScripts/get_output.py @@ -0,0 +1,36 @@ +import re +import numpy as np +import sys +import os +import parse_config + +def convert_raw_output_to_np(filename, bitlength, scale): + matcher = re.compile(r"[-]?[0-9]+") + scaled_array = [] + with open(filename, "r") as f: + for line in f: + match = matcher.fullmatch(line.rstrip()) + if match: + unsigned_number = int(match.group(0)) + number = ( + unsigned_number + if (unsigned_number < 2 ** (bitlength - 1)) + else unsigned_number - 2 ** bitlength + ) + scaled_array.append(float(number) / (2 ** scale)) + return np.array(scaled_array) + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print("Usage: python get_output.py mpc_output.txt config.json") + output_fname = sys.argv[1] + config_name = sys.argv[2] + params = parse_config.get_params(config_name) + scale = params["scale"] + bitlength = params["bitlength"] + bitlength = 64 if bitlength is None else bitlength + model_name = os.path.splitext(params["model_name"])[0] + np_arr = convert_raw_output_to_np(output_fname, bitlength, scale) + np.save(model_name + "_output", np_arr) + print("Output dumped as np array in " + model_name + "_output.npy") diff --git a/Athos/tests/utils.py b/Athos/tests/utils.py index e0a3c280..9b05da09 100644 --- a/Athos/tests/utils.py +++ b/Athos/tests/utils.py @@ -28,8 +28,9 @@ import re sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -import CompilerScripts.parse_config as parse_config import CompileTFGraph +import CompilerScripts.parse_config as parse_config +from CompilerScripts.get_output import convert_raw_output_to_np import numpy as np import subprocess @@ -105,23 +106,6 @@ def save_graph(graph_def, config, test_dir): return -def convert_raw_output_to_np(filename, bitlength, scale): - matcher = re.compile(r"[-]?[0-9]+") - scaled_array = [] - with open(filename, "r") as f: - for line in f: - match = matcher.fullmatch(line.rstrip()) - if match: - unsigned_number = int(match.group(0)) - number = ( - unsigned_number - if (unsigned_number < 2 ** (bitlength - 1)) - else unsigned_number - 2 ** bitlength - ) - scaled_array.append(float(number) / (2 ** scale)) - return np.array(scaled_array) - - class Program: def __init__(self, program_path, model_weight_path, params, test_dir): self.program_path = program_path From 47b08c5e439449cfa77831037298c7519cb03216 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 3 Feb 2021 15:46:45 +0530 Subject: [PATCH 52/72] Script to convert SavedModel tensorflow graphs to frozen graphs. --- .../convert_saved_model_to_frozen_graph.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py diff --git a/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py b/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py new file mode 100644 index 00000000..d2c224a9 --- /dev/null +++ b/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py @@ -0,0 +1,46 @@ +""" + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +""" +import sys +import tensorflow as tf +from tensorflow.python.framework.convert_to_constants import ( + convert_variables_to_constants_v2, +) + +tf.enable_eager_execution() + +if __name__ == "__main__": + if (len(sys.argv) != 2): + print("Usage: python convert_saved_model_to_frozen_graph.py saved_model_directory") + model_dir = sys.argv[1] + saved_model = tf.saved_model.load_v2(model_dir) + model = saved_model.signatures["serving_default"] + full_model = tf.function(lambda x: model(x)) + full_model = full_model.get_concrete_function( + tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype) + ) + frozen_function = convert_variables_to_constants_v2(full_model) + gd = frozen_function.graph.as_graph_def() + with tf.io.gfile.GFile(model_dir + "/frozen_graph.pb", "wb") as f: + f.write(gd.SerializeToString()) + print("Frozen graph saved in " + model_dir + "/frozen_graph.pb") From bef758477afff2845ca6ec9da6ef2f283cad2272 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 3 Feb 2021 16:21:40 +0530 Subject: [PATCH 53/72] Script to remove nodes from a tensorflow graph Useful only for removing output nodes. If you remove arbitrary nodes, it will produce broken graphs. --- Athos/CompilerScripts/parse_config.py | 1 + Athos/CompilerScripts/remove_tf_nodes.py | 64 ++++++++++++++++++++++++ Athos/CompilerScripts/tf_graph_io.py | 7 ++- 3 files changed, 70 insertions(+), 2 deletions(-) create mode 100755 Athos/CompilerScripts/remove_tf_nodes.py diff --git a/Athos/CompilerScripts/parse_config.py b/Athos/CompilerScripts/parse_config.py index 68a0f5bc..d55c85eb 100644 --- a/Athos/CompilerScripts/parse_config.py +++ b/Athos/CompilerScripts/parse_config.py @@ -24,6 +24,7 @@ import argparse import os.path import json +from json import JSONDecodeError import sys """ diff --git a/Athos/CompilerScripts/remove_tf_nodes.py b/Athos/CompilerScripts/remove_tf_nodes.py new file mode 100755 index 00000000..7e254c61 --- /dev/null +++ b/Athos/CompilerScripts/remove_tf_nodes.py @@ -0,0 +1,64 @@ +""" + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +""" + +import tensorflow as tf +import sys +import json +from json import JSONDecodeError +from tf_graph_io import load_graph_def_pb, dump_graph_def_pb + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print( + """Usage: python remove_node.py config.json +config.json should have the following fields. +{ + "model_name" : "model.pb", + "nodes_to_remove" : ["loss", "model_outputs"] +} + +""" + ) + sys.exit() + config_path = sys.argv[1] + with open(config_path) as f: + try: + config = json.load(f) + except JSONDecodeError as e: + sys.exit( + "Error while parsing the config json:\n" + + e.msg + + " at line no. " + + str(e.lineno) + ) + model_name = config["model_name"] + nodes_to_remove = config["nodes_to_remove"] + gd = load_graph_def_pb(model_name) + to_remove = [n for n in gd.node if n.name in nodes_to_remove] + for i in to_remove: + gd.node.remove(i) + new_graph_name = "processed_" + model_name + dump_graph_def_pb(gd, new_graph_name) + print("Pruned graph is dumped in {}".format(new_graph_name)) diff --git a/Athos/CompilerScripts/tf_graph_io.py b/Athos/CompilerScripts/tf_graph_io.py index e23cc2d1..c8fe8cb1 100644 --- a/Athos/CompilerScripts/tf_graph_io.py +++ b/Athos/CompilerScripts/tf_graph_io.py @@ -29,11 +29,14 @@ def display_graph(graph, tensorboard_log_dir): writer = tf.summary.FileWriter(tensorboard_log_dir, graph) writer.close() - -def load_pb(path_to_pb): +def load_graph_def_pb(path_to_pb): with tf.io.gfile.GFile(path_to_pb, "rb") as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) + return graph_def + +def load_pb(path_to_pb): + graph_def = load_graph_def_pb(path_to_pb) with tf.Graph().as_default() as graph: tf.import_graph_def(graph_def, name="") return graph From c627b56b0dc710e7cf566eec94178f78952be71d Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 3 Feb 2021 17:00:49 +0530 Subject: [PATCH 54/72] Add grouped conv in Porthos, SCI_OT --- Athos/tests/tf/unittests/test_convolution.py | 2 + Porthos/src/EzPCFunctionalities.cpp | 12 ++++++ Porthos/src/EzPCFunctionalities.h | 18 ++++++++ SCI/src/functionalities_wrapper.h | 44 ++++++++++++++++++++ 4 files changed, 76 insertions(+) diff --git a/Athos/tests/tf/unittests/test_convolution.py b/Athos/tests/tf/unittests/test_convolution.py index 8c98ce7f..08dc5216 100644 --- a/Athos/tests/tf/unittests/test_convolution.py +++ b/Athos/tests/tf/unittests/test_convolution.py @@ -77,6 +77,8 @@ def test_conv(test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, def test_depthwise_conv( test_dir, backend, tfOp, a_shape, kernel_shape, strides, padding, dtype ): + if backend in ["2PC_HE"]: + pytest.skip("[SCI][grouped_conv] Missing Support in SCI") graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) kernel_inp = dtype(np.random.randn(*kernel_shape)) diff --git a/Porthos/src/EzPCFunctionalities.cpp b/Porthos/src/EzPCFunctionalities.cpp index 62de7ce9..3db468cd 100644 --- a/Porthos/src/EzPCFunctionalities.cpp +++ b/Porthos/src/EzPCFunctionalities.cpp @@ -398,6 +398,18 @@ void Conv2DWrapper(int32_t N, int32_t H, int32_t W, int32_t CI, } +void Conv2DGroupWrapper(int32_t N, int32_t H, int32_t W, int32_t CI, + int32_t FH, int32_t FW, int32_t CO, + int32_t zPadHLeft, int32_t zPadHRight, + int32_t zPadWLeft, int32_t zPadWRight, + int32_t strideH, int32_t strideW, int32_t G, + vector< vector< vector< vector > > >& inputArr, + vector< vector< vector< vector > > >& filterArr, + vector< vector< vector< vector > > >& outArr) +{ + Conv2DGroup(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight, strideH, strideW, G, inputArr, filterArr, outArr); +} + void Conv3DWrapper(int32_t N, int32_t D, int32_t H, int32_t W, int32_t CI, int32_t FD, int32_t FH, int32_t FW, int32_t CO, int32_t zPadDLeft, int32_t zPadDRight, diff --git a/Porthos/src/EzPCFunctionalities.h b/Porthos/src/EzPCFunctionalities.h index eadd8b81..782c4b38 100644 --- a/Porthos/src/EzPCFunctionalities.h +++ b/Porthos/src/EzPCFunctionalities.h @@ -113,6 +113,24 @@ void Conv2DWrapper(int32_t N, int32_t H, int32_t W, int32_t CI, vector< vector< vector< vector > > >& filterArr, vector< vector< vector< vector > > >& outArr); +void Conv2DGroup(int32_t N, int32_t H, int32_t W, int32_t CI, + int32_t FH, int32_t FW, int32_t CO, + int32_t zPadHLeft, int32_t zPadHRight, + int32_t zPadWLeft, int32_t zPadWRight, + int32_t strideH, int32_t strideW, int32_t G, + vector< vector< vector< vector > > >& inputArr, + vector< vector< vector< vector > > >& filterArr, + vector< vector< vector< vector > > >& outArr); + +void Conv2DGroupWrapper(int32_t N, int32_t H, int32_t W, int32_t CI, + int32_t FH, int32_t FW, int32_t CO, + int32_t zPadHLeft, int32_t zPadHRight, + int32_t zPadWLeft, int32_t zPadWRight, + int32_t strideH, int32_t strideW, int32_t G, + vector< vector< vector< vector > > >& inputArr, + vector< vector< vector< vector > > >& filterArr, + vector< vector< vector< vector > > >& outArr); + void Conv3D(int32_t N, int32_t D, int32_t H, int32_t W, int32_t CI, int32_t FD, int32_t FH, int32_t FW, int32_t CO, int32_t zPadDLeft, int32_t zPadDRight, diff --git a/SCI/src/functionalities_wrapper.h b/SCI/src/functionalities_wrapper.h index 73ef3d81..983ae3d0 100644 --- a/SCI/src/functionalities_wrapper.h +++ b/SCI/src/functionalities_wrapper.h @@ -42,6 +42,14 @@ void Conv2D(int32_t N, int32_t H, int32_t W, int32_t CI, intType* inputArr, intType* filterArr, intType* outArr); +void Conv2DGroup(int32_t N, int32_t H, int32_t W, int32_t CI, + int32_t FH, int32_t FW, int32_t CO, + int32_t zPadHLeft, int32_t zPadHRight, + int32_t zPadWLeft, int32_t zPadWRight, + int32_t strideH, int32_t strideW, int32_t G, + intType* inputArr, intType* filterArr, + intType* outArr); + void MatMul2D(int32_t s1, int32_t s2, int32_t s3, const intType* A, const intType* B, intType* C, bool modelIsA) { #ifdef LOG_LAYERWISE @@ -948,6 +956,42 @@ void Conv2DWrapper(signedIntType N, signedIntType H, signedIntType W, signedIntT #endif } +void Conv2DGroupWrapper(signedIntType N, signedIntType H, signedIntType W, signedIntType CI, + signedIntType FH, signedIntType FW, signedIntType CO, + signedIntType zPadHLeft, signedIntType zPadHRight, + signedIntType zPadWLeft, signedIntType zPadWRight, + signedIntType strideH, signedIntType strideW, signedIntType G, + intType* inputArr, + intType* filterArr, + intType* outArr) +{ +#ifdef LOG_LAYERWISE + INIT_ALL_IO_DATA_SENT; + INIT_TIMER; +#endif + + static int ctr = 1; + std::cout<<"Conv2DGroupCSF "< Date: Tue, 9 Feb 2021 14:10:09 +0530 Subject: [PATCH 58/72] Add script to convert tf nodes to identity --- .../convert_saved_model_to_frozen_graph.py | 2 +- Athos/CompilerScripts/create_tf_input.py | 4 +- .../replace_tf_nodes_with_identity.py | 66 +++++++++++++++++++ 3 files changed, 68 insertions(+), 4 deletions(-) create mode 100755 Athos/CompilerScripts/replace_tf_nodes_with_identity.py diff --git a/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py b/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py index d2c224a9..fc0815ed 100644 --- a/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py +++ b/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py @@ -31,7 +31,7 @@ if __name__ == "__main__": if (len(sys.argv) != 2): - print("Usage: python convert_saved_model_to_frozen_graph.py saved_model_directory") + sys.exit("Usage: python convert_saved_model_to_frozen_graph.py saved_model_directory") model_dir = sys.argv[1] saved_model = tf.saved_model.load_v2(model_dir) model = saved_model.signatures["serving_default"] diff --git a/Athos/CompilerScripts/create_tf_input.py b/Athos/CompilerScripts/create_tf_input.py index 35906045..e4bc2854 100644 --- a/Athos/CompilerScripts/create_tf_input.py +++ b/Athos/CompilerScripts/create_tf_input.py @@ -74,9 +74,7 @@ def gen_random_input( f.write(chunk_inp) f.close() if dump_numpy: - rand_inp_t.dump( - model_name + "_input_fixedpt_scale_" + str(scaling_factor) + ".npy" - ) + np.save(model_name + "_input_fixedpt_scale_" + str(scaling_factor), rand_inp_t) return diff --git a/Athos/CompilerScripts/replace_tf_nodes_with_identity.py b/Athos/CompilerScripts/replace_tf_nodes_with_identity.py new file mode 100755 index 00000000..1f809f3a --- /dev/null +++ b/Athos/CompilerScripts/replace_tf_nodes_with_identity.py @@ -0,0 +1,66 @@ +""" + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +""" + +import tensorflow as tf +import sys +import json +from json import JSONDecodeError +from tf_graph_io import load_graph_def_pb, dump_graph_def_pb + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print( + """Usage: python remove_node.py config.json +config.json should have the following fields. +{ + "model_name" : "model.pb", + "nodes_to_replace" : ["loss", "model_outputs"] +} + +""" + ) + sys.exit() + config_path = sys.argv[1] + with open(config_path) as f: + try: + config = json.load(f) + except JSONDecodeError as e: + sys.exit( + "Error while parsing the config json:\n" + + e.msg + + " at line no. " + + str(e.lineno) + ) + model_name = config["model_name"] + nodes_to_replace = config["nodes_to_remove"] + gd = load_graph_def_pb(model_name) + to_replace = [n for n in gd.node if n.name in nodes_to_replace] + for n in gd.node: + if n.name in nodes_to_replace: + n.op = "Identity" + + new_graph_name = "processed_" + model_name + dump_graph_def_pb(gd, new_graph_name) + print("Pruned graph is dumped in {}".format(new_graph_name)) From 71ea77a2ddaa12a19d745027fc7924f62c244fa0 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Tue, 9 Feb 2021 14:10:51 +0530 Subject: [PATCH 59/72] Add support for multiple output nodes. Before we used to treat the last computation as the default output node. However now we manually create output nodes based on user specification. If the user doesnt specify any output, we resort to old behaviour. --- Athos/CompileTFGraph.py | 2 +- Athos/SeeDot/AST/AST.py | 7 ++ Athos/SeeDot/AST/ASTVisitor.py | 10 +++ Athos/SeeDot/AST/MtdAST.py | 3 + Athos/SeeDot/AST/PrintAST.py | 6 ++ Athos/SeeDot/Codegen/CodegenBase.py | 2 + Athos/SeeDot/Codegen/EzPC.py | 17 ++-- Athos/SeeDot/Compiler.py | 78 +++++++++++-------- Athos/SeeDot/IR/IR.py | 13 ++++ Athos/SeeDot/IR/IRBuilderCSF.py | 13 +++- .../SeeDot/Optimizations/GarbageCollector.py | 6 ++ Athos/SeeDot/Type.py | 9 ++- Athos/TFCompiler/ProcessTFGraph.py | 51 +++++++++++- 13 files changed, 174 insertions(+), 43 deletions(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 27a10bf7..25fb4c5c 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -129,7 +129,7 @@ def generate_code(params, debug=False): ) # Compile to seedot. Generate AST in model directory - Athos.process_tf_graph(model_abs_path) + Athos.process_tf_graph(model_abs_path, output_tensors) # Compile to ezpc model_base_name = os.path.basename(model_abs_path)[:-3] diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index 3c47b0fc..d69eeb63 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -449,6 +449,13 @@ def __init__( self.isSecret = isSecret self.inputByParty = inputByParty +class Output(ASTNode): + def __init__(self, expr: ASTNode, outputToParty=Party.CLIENT): + if assertInputTypes: + assert (outputToParty in [Party.CLIENT, Party.SERVER]) + super().__init__() + self.expr = expr + self.outputToParty = outputToParty # Since some optimizations are possible around batchnorm, keep this as an interpreted node class FusedBatchNorm(ASTNode): diff --git a/Athos/SeeDot/AST/ASTVisitor.py b/Athos/SeeDot/AST/ASTVisitor.py index 81835f84..367d21ad 100644 --- a/Athos/SeeDot/AST/ASTVisitor.py +++ b/Athos/SeeDot/AST/ASTVisitor.py @@ -26,6 +26,9 @@ class ASTVisitor: + def visitASTNode(self, node: AST.ASTNode, args=None): + pass + def visitInt(self, node: AST.Int, args=None): pass @@ -81,6 +84,9 @@ def visitReduce(self, node: AST.Reduce, args=None): def visitInput(self, node: AST.Input, args=None): pass + def visitOutput(self, node: AST.Output, args=None): + pass + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None): self.visit(node.expr, args) self.visit(node.multExpr, args) @@ -121,8 +127,12 @@ def visit(self, node, args=None): return self.visitReduce(node, args) elif isinstance(node, AST.Input): return self.visitInput(node, args) + elif isinstance(node, AST.Output): + return self.visitOutput(node, args) elif isinstance(node, AST.FusedBatchNorm): return self.visitFusedBatchNorm(node, args) + elif isinstance(node, AST.ASTNode): + return self.visitASTNode(node, args) elif node: raise Exception("Node instance not matched.") else: diff --git a/Athos/SeeDot/AST/MtdAST.py b/Athos/SeeDot/AST/MtdAST.py index 27c9d938..efd70830 100644 --- a/Athos/SeeDot/AST/MtdAST.py +++ b/Athos/SeeDot/AST/MtdAST.py @@ -90,6 +90,9 @@ def visitReduce(self, node: AST.Reduce, mtd: dict): def visitInput(self, node: AST.Input, mtd: dict): node.metadata.update(mtd) + def visitOutput(self, node: AST.Output, mtd: dict): + node.metadata.update(mtd) + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, mtd: dict): node.metadata.update(mtd) self.visit(node.expr, mtd) diff --git a/Athos/SeeDot/AST/PrintAST.py b/Athos/SeeDot/AST/PrintAST.py index ccc4b7c0..3411d700 100644 --- a/Athos/SeeDot/AST/PrintAST.py +++ b/Athos/SeeDot/AST/PrintAST.py @@ -148,6 +148,12 @@ def visitInput(self, node: AST.Input, args=None): ) print(" )", end="") + def visitOutput(self, node: AST.Output, args=None): + print(indent * node.depth, "output( ", end="") + node.expr.depth = node.depth + 1 + self.visit(node.expr) + print(indent * node.depth, " )", end="") + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None): node.expr.depth = node.multExpr.depth = node.addExpr.depth = node.depth + 1 print(indent * node.depth, "FusedBatchNorm", end=" ") diff --git a/Athos/SeeDot/Codegen/CodegenBase.py b/Athos/SeeDot/Codegen/CodegenBase.py index 1e94f857..8e26472b 100644 --- a/Athos/SeeDot/Codegen/CodegenBase.py +++ b/Athos/SeeDot/Codegen/CodegenBase.py @@ -222,6 +222,8 @@ def print(self, ir): return self.printOp(ir) elif isinstance(ir, IR.Input): return self.printInput(ir) + elif isinstance(ir, IR.Output): + return self.printOutput(ir) elif isinstance(ir, IR.Decl): return self.printDecl(ir) else: diff --git a/Athos/SeeDot/Codegen/EzPC.py b/Athos/SeeDot/Codegen/EzPC.py index ac5a3589..b3009369 100644 --- a/Athos/SeeDot/Codegen/EzPC.py +++ b/Athos/SeeDot/Codegen/EzPC.py @@ -109,8 +109,8 @@ def printInt(self, ir: IR.Int): def printInput(self, ir: IR.Input): inputByPartyStr = ir.inputByParty.name assert ( - inputByPartyStr == "SERVER" or inputByPartyStr == "CLIENT" - ) # For now the only supported values of party to input is 0 or 1 + inputByPartyStr in ["SERVER", "CLIENT"] + ) self.out.printf( "input({0}, {1}, ".format(inputByPartyStr, ir.expr.idf), indent=True ) @@ -129,6 +129,15 @@ def printInput(self, ir: IR.Input): self.out.printf("[" + str(curDim) + "]") self.out.printf(");\n\n") + def printOutput(self, ir: IR.Output): + outputByPartyStr = ir.outputToParty.name + assert ( + outputByPartyStr in ["SERVER", "CLIENT"] + ) + self.out.printf( + "output({0}, {1});\n\n".format(outputByPartyStr, ir.expr.idf), indent=True + ) + def printComment(self, ir): self.out.printf("(* " + ir.msg + " *)\n", indent=True) @@ -151,9 +160,5 @@ def printDecl(self, ir): self.out.printf(";\n\n") def _out_suffix(self, expr: IR.Expr): - if self.debugVar is None: - self.out.printf("output(CLIENT, " + expr.idf + ");\n", indent=True) - else: - self.out.printf("output(CLIENT, " + self.debugVar + ");\n", indent=True) self.out.decreaseIndent() self.out.printf("}\n", indent=True) diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index ffed6aad..5d1ff409 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -93,40 +93,55 @@ def insertStartEndFunctionCalls(self, res: (IR.Prog, IR.Expr)): ): prog.cmd_l.insert(ii, IR.FuncCall("StartComputation", [])) break - prog.cmd_l.append(IR.FuncCall("EndComputation", [])) + + first_output_pos = None + for ii in range(len(prog.cmd_l)): + if isinstance(prog.cmd_l[ii], IR.Output): + first_output_pos = ii + break + prog.cmd_l.insert(first_output_pos, IR.FuncCall("EndComputation", [])) return (prog, expr) def fixOuputScale(self, res: (IR.Prog, IR.Expr), compiler: IRBuilderCSF): - prog = res[0] - expr = res[1] - output_scale = compiler.scaleFacMapping[expr.idf] - if output_scale == -1 or output_scale == Util.Config.consSF: - return (prog, expr) - elif output_scale > Util.Config.consSF: - scale_down = output_scale - Util.Config.consSF - type = compiler.typeInfo[expr.idf] - if Type.isInt(type): - output_shape = [] - if Type.isTensor(type): - output_shape = type.shape - - argsDict = OrderedDict() - funcName = "ScaleDown" - for ii, curDimSize in enumerate(output_shape): - argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) - funcName = funcName + str(len(output_shape)) - argsDict[expr] = "expr" - argsDict[IR.Int(scale_down, 32)] = "consSF" - funcCall = IR.FuncCall(funcName, argsDict) - new_prog = IR.Prog([funcCall]) - prog = IRUtil.prog_merge(prog, new_prog) - return (prog, expr) - else: - assert ( - False - ), "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format( - output_scale, Util.Config.consSF - ) + (prog, expr) = res + scaledown_cmd_list = [] + + first_output_pos = None + i = 0 + for cmd in prog.cmd_l: + if type(cmd) == IR.Output: + if first_output_pos is None: + first_output_pos = i + var = cmd.expr + assert type(var) == IR.Var + output_scale = compiler.scaleFacMapping[var.idf] + if output_scale > Util.Config.consSF: + scale_down = output_scale - Util.Config.consSF + var_type = compiler.typeInfo[var.idf] + if Type.isInt(var_type): + output_shape = [] + if Type.isTensor(var_type): + output_shape = var_type.shape + argsDict = OrderedDict() + funcName = "ScaleDown" + for ii, curDimSize in enumerate(output_shape): + argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii) + funcName = funcName + str(len(output_shape)) + argsDict[var] = "expr" + argsDict[IR.Int(scale_down, 32)] = "consSF" + funcCall = IR.FuncCall(funcName, argsDict) + scaledown_cmd_list.append(funcCall) + if output_scale < Util.Config.consSF: + assert ( + False + ), "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format( + output_scale, Util.Config.consSF + ) + i+=1 + final_cmd_list = prog.cmd_l[0:first_output_pos] + scaledown_cmd_list + prog.cmd_l[first_output_pos:] + prog = IR.Prog(final_cmd_list) + return (prog, expr) + def run(self): with open(Util.Config.astFile, "rb") as ff: @@ -145,6 +160,7 @@ def run(self): GC.run([mtdAST]) print("Garbage collection done.") + # Perform type inference and annotate nodes with type information InferType().visit(ast) diff --git a/Athos/SeeDot/IR/IR.py b/Athos/SeeDot/IR/IR.py index 89c87102..7b3a84cb 100644 --- a/Athos/SeeDot/IR/IR.py +++ b/Athos/SeeDot/IR/IR.py @@ -374,6 +374,19 @@ def subst(self, from_idf: str, to_e: Expr): self.inputByParty, ) +class Output(Cmd): + def __init__(self, + expr: Expr, + outputToParty: AST.Party + ): + self.expr = expr + self.outputToParty = outputToParty + + def subst(self, from_idf: str, to_e: Expr): + return self.__class__( + self.expr.subst(from_idf, to_e), + self.outputToParty, + ) class Decl(Cmd): def __init__( diff --git a/Athos/SeeDot/IR/IRBuilderCSF.py b/Athos/SeeDot/IR/IRBuilderCSF.py index ee4697e2..fc821585 100644 --- a/Athos/SeeDot/IR/IRBuilderCSF.py +++ b/Athos/SeeDot/IR/IRBuilderCSF.py @@ -1414,8 +1414,8 @@ def visitFloorLike(self, node: AST.Func, args=None): AST.Operators.RSQRT, ]: # Since these class of fucntions can only handle input of 32 bitlength, we have to scale down - # inputs before calling them. - if final_sf > 32: + # inputs before calling them. 23 bit mantissa + if final_sf > 23: assert ( final_sf > self.scaleFac ), "The program scaling factor is invalid. Should be lesser than 32 if network has tan/sig/sqrt" @@ -1622,6 +1622,15 @@ def visitInput(self, node: AST.Input, args=None): returnExpr, ) + def visitOutput(self, node: AST.Output, args=None): + (prog_0, expr_0) = self.visit(node.expr) + output = IR.Output(expr_0, node.outputToParty) + prog = IRUtil.prog_merge(prog_0, IR.Prog([output])) + expr = self.getTempVar() + if not (Util.Config.disableTruncOpti): + self.scaleFacMapping[expr.idf] = self.scaleFac + return (prog, expr) + def visitReduce(self, node: AST.Reduce, args=None): (prog_1, expr1) = self.visit(node.expr) assert node.op in [AST.Operators.ADD, AST.Operators.Mean] diff --git a/Athos/SeeDot/Optimizations/GarbageCollector.py b/Athos/SeeDot/Optimizations/GarbageCollector.py index 1cefa810..deee36c7 100644 --- a/Athos/SeeDot/Optimizations/GarbageCollector.py +++ b/Athos/SeeDot/Optimizations/GarbageCollector.py @@ -45,6 +45,9 @@ def visitFloat(self, node: AST.Float, args): def visitInput(self, node: AST.Input, args): self.node_to_secret[node] = node.isSecret + def visitOutput(self, node: AST.Output, args): + self.node_to_secret[node] = self.idf_to_secret[node.expr.name] + def visitId(self, node: AST.ID, args): self.node_to_secret[node] = self.idf_to_secret[node.name] @@ -239,6 +242,9 @@ def visitFloat(self, node: AST.Float, args): def visitInput(self, node: AST.Input, args): return set() + def visitOutput(self, node: AST.Input, args): + return set() + def visitId(self, node: AST.ID, args): return set([node.name]) diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 4d549a4c..74473a4e 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -29,7 +29,7 @@ from AST.ASTVisitor import ASTVisitor from enum import Enum, auto import copy - +import sys class Type: pass @@ -196,6 +196,7 @@ def visitId(self, node: AST.ID, args=None): if node.name not in node.gamma: print( "Error in type checking: Found id which is not contained in gamma.", + node.name, file=sys.stderr, ) assert False @@ -572,6 +573,12 @@ def visitInput(self, node: AST.Input, args=None): ) return node.type + def visitOutput(self, node: AST.Output, args=None): + node.expr.gamma = dict(node.gamma) + self.visit(node.expr) + node.type = Unit() + return node.type + def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None): cur_gamma = dict(node.gamma) node.expr.gamma = cur_gamma diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index a6fed45a..40a29104 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -82,7 +82,7 @@ def generateIRCode(graph, extraInfoDict): ) if program: assert type(innerMostLetASTNode) is AST.Let - newNode = AST.Let(curOutVarAstNode, curAst, curOutVarAstNode) + newNode = AST.Let(curOutVarAstNode, curAst, AST.ASTNode()) mtdAST.visit(newNode, mtdForCurAST) innerMostLetASTNode.expr = newNode innerMostLetASTNode = newNode @@ -98,6 +98,51 @@ def generateIRCode(graph, extraInfoDict): return (program, dictNodeNameToOutVarStr) + +def addOutputs(program, dictNodeNameToOutVarStr, output_tensors): + mtdAST = MtdAST() + assert (type(program) is AST.Let) + lastLetASTNode = program + while True: + if type(lastLetASTNode.expr) is AST.Let: + lastLetASTNode = lastLetASTNode.expr + else: + break + assert lastLetASTNode is not None + if output_tensors is None: + output_name = lastLetASTNode.name + print(output_name.name) + output = AST.Output(output_name, AST.Party.CLIENT) + lastLetASTNode.expr = output + else: + outVarCt = 0 + outVarPrefix = "O" + for i in range(0, len(output_tensors)): # name, decl, expr + t_name = output_tensors[i] + if i == len(output_tensors) - 1: + output_name = AST.ID(dictNodeNameToOutVarStr[t_name]) + output = AST.Output(output_name, AST.Party.CLIENT) + newNode = output + else: + output_name = AST.ID(dictNodeNameToOutVarStr[t_name]) + output = AST.Output(output_name, AST.Party.CLIENT) + let_name_id = AST.ID(outVarPrefix + str(outVarCt)) + newNode = AST.Let(let_name_id, output, AST.ASTNode()) + mtdForCurAST = { + AST.ASTNode.mtdKeyTFOpName: "Output", + AST.ASTNode.mtdKeyTFNodeName: t_name, + } + mtdAST.visit(newNode, mtdForCurAST) + lastLetASTNode.expr = newNode + lastLetASTNode = newNode + outVarCt += 1 + + + print("Wooooo!") + #TODO + return program + + def readSizeInfo(fileName): allLines = None with open(fileName) as f: @@ -237,7 +282,7 @@ def topo_sort(v): return -def process_tf_graph(filename): +def process_tf_graph(filename, output_tensors=None): sys.setrecursionlimit(10000) if os.path.isfile(filename): @@ -283,6 +328,8 @@ def process_tf_graph(filename): print("Generating code from TF graph def : ", graphFileName, " ...") (program, dictNodeNameToOutVarStr) = generateIRCode(graph, extraInfoDict) + program = addOutputs(program, dictNodeNameToOutVarStr, output_tensors) + print("SeeDot AST generation done. Pickling the AST.") with open(os.path.join(folderName, "astOutput.pkl"), "wb") as f: From 51bcb443526d4abdd28f2b4385d0c9f0a042becc Mon Sep 17 00:00:00 2001 From: Bhatu Date: Mon, 15 Feb 2021 13:23:38 +0530 Subject: [PATCH 60/72] Add a pre-commit hook to format python with black. You need to locally do cp Athos/HelperScripts/pre_commit_format_python.sh .git/hooks/pre-commit This will run black on all staged files. --- .../convert_saved_model_to_frozen_graph.py | 6 +- Athos/CompilerScripts/get_output.py | 1 + Athos/CompilerScripts/tf_graph_io.py | 2 + .../HelperScripts/pre_commit_format_python.sh | 65 +++++++++++++++++++ Athos/SeeDot/AST/AST.py | 4 +- Athos/SeeDot/Codegen/EzPC.py | 8 +-- Athos/SeeDot/Compiler.py | 10 +-- Athos/SeeDot/IR/IR.py | 7 +- Athos/SeeDot/Type.py | 1 + Athos/TFCompiler/ProcessTFGraph.py | 9 +-- 10 files changed, 89 insertions(+), 24 deletions(-) create mode 100755 Athos/HelperScripts/pre_commit_format_python.sh diff --git a/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py b/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py index fc0815ed..ccb92794 100644 --- a/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py +++ b/Athos/CompilerScripts/convert_saved_model_to_frozen_graph.py @@ -30,8 +30,10 @@ tf.enable_eager_execution() if __name__ == "__main__": - if (len(sys.argv) != 2): - sys.exit("Usage: python convert_saved_model_to_frozen_graph.py saved_model_directory") + if len(sys.argv) != 2: + sys.exit( + "Usage: python convert_saved_model_to_frozen_graph.py saved_model_directory" + ) model_dir = sys.argv[1] saved_model = tf.saved_model.load_v2(model_dir) model = saved_model.signatures["serving_default"] diff --git a/Athos/CompilerScripts/get_output.py b/Athos/CompilerScripts/get_output.py index d7ca78cb..76a328d5 100644 --- a/Athos/CompilerScripts/get_output.py +++ b/Athos/CompilerScripts/get_output.py @@ -4,6 +4,7 @@ import os import parse_config + def convert_raw_output_to_np(filename, bitlength, scale): matcher = re.compile(r"[-]?[0-9]+") scaled_array = [] diff --git a/Athos/CompilerScripts/tf_graph_io.py b/Athos/CompilerScripts/tf_graph_io.py index c8fe8cb1..82c90c68 100644 --- a/Athos/CompilerScripts/tf_graph_io.py +++ b/Athos/CompilerScripts/tf_graph_io.py @@ -29,12 +29,14 @@ def display_graph(graph, tensorboard_log_dir): writer = tf.summary.FileWriter(tensorboard_log_dir, graph) writer.close() + def load_graph_def_pb(path_to_pb): with tf.io.gfile.GFile(path_to_pb, "rb") as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) return graph_def + def load_pb(path_to_pb): graph_def = load_graph_def_pb(path_to_pb) with tf.Graph().as_default() as graph: diff --git a/Athos/HelperScripts/pre_commit_format_python.sh b/Athos/HelperScripts/pre_commit_format_python.sh new file mode 100755 index 00000000..feefddd3 --- /dev/null +++ b/Athos/HelperScripts/pre_commit_format_python.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# Git pre-commit hook to check staged Python files for formatting issues with +# black. +# +# INSTALLING: Copy this script into `.git/hooks/pre-commit`, and mark it as +# executable. +# +# This requires that black is installed and runnable in the environment running +# the pre-commit hook. +# +# When running, this first checks for unstaged changes to staged files, and if +# there are any, it will exit with an error. Files with unstaged changes will be +# printed. +# +# If all staged files have no unstaged changes, it will run black against them, +# leaving the formatting changes unstaged. Changed files will be printed. +# +# BUGS: This does not leave staged changes alone when used with the -a flag to +# git commit, due to the fact that git stages ALL unstaged files when that flag +# is used. + +# Find all staged Python files, and exit early if there aren't any. +PYTHON_FILES=() +while IFS=$'\n' read -r line; do PYTHON_FILES+=("$line"); done \ + < <(git diff --name-only --cached --diff-filter=AM | grep --color=never '.py$') +if [ ${#PYTHON_FILES[@]} -eq 0 ]; then + exit 0 +fi + +# Verify that black is installed; if not, warn and exit. +if ! command -v black >/dev/null; then + echo 'black not on path; can not format. Please install black:' + echo ' pip install black' + exit 2 +fi + +# Check for unstaged changes to files in the index. +CHANGED_FILES=() +while IFS=$'\n' read -r line; do CHANGED_FILES+=("$line"); done \ + < <(git diff --name-only "${PYTHON_FILES[@]}") +if [ ${#CHANGED_FILES[@]} -gt 0 ]; then + echo 'You have unstaged changes to some files in your commit; skipping ' + echo 'auto-format. Please stage, stash, or revert these changes. You may ' + echo 'find `git stash -k` helpful here.' + echo 'Files with unstaged changes:' "${CHANGED_FILES[@]}" + exit 1 +fi + +# Format all staged files, then exit with an error code if any have uncommitted +# changes. +echo 'Formatting staged Python files . . .' +black -t py37 "${PYTHON_FILES[@]}" + + +CHANGED_FILES=() +while IFS=$'\n' read -r line; do CHANGED_FILES+=("$line"); done \ + < <(git diff --name-only "${PYTHON_FILES[@]}") +if [ ${#CHANGED_FILES[@]} -gt 0 ]; then + echo 'Reformatted staged files. Please review and stage the changes.' + echo 'Files updated: ' "${CHANGED_FILES[@]}" + exit 1 +else + exit 0 +fi diff --git a/Athos/SeeDot/AST/AST.py b/Athos/SeeDot/AST/AST.py index d69eeb63..3209217e 100644 --- a/Athos/SeeDot/AST/AST.py +++ b/Athos/SeeDot/AST/AST.py @@ -449,14 +449,16 @@ def __init__( self.isSecret = isSecret self.inputByParty = inputByParty + class Output(ASTNode): def __init__(self, expr: ASTNode, outputToParty=Party.CLIENT): if assertInputTypes: - assert (outputToParty in [Party.CLIENT, Party.SERVER]) + assert outputToParty in [Party.CLIENT, Party.SERVER] super().__init__() self.expr = expr self.outputToParty = outputToParty + # Since some optimizations are possible around batchnorm, keep this as an interpreted node class FusedBatchNorm(ASTNode): def __init__(self, expr: ID, multExpr: ID, addExpr: ID): diff --git a/Athos/SeeDot/Codegen/EzPC.py b/Athos/SeeDot/Codegen/EzPC.py index b3009369..8178c3da 100644 --- a/Athos/SeeDot/Codegen/EzPC.py +++ b/Athos/SeeDot/Codegen/EzPC.py @@ -108,9 +108,7 @@ def printInt(self, ir: IR.Int): def printInput(self, ir: IR.Input): inputByPartyStr = ir.inputByParty.name - assert ( - inputByPartyStr in ["SERVER", "CLIENT"] - ) + assert inputByPartyStr in ["SERVER", "CLIENT"] self.out.printf( "input({0}, {1}, ".format(inputByPartyStr, ir.expr.idf), indent=True ) @@ -131,9 +129,7 @@ def printInput(self, ir: IR.Input): def printOutput(self, ir: IR.Output): outputByPartyStr = ir.outputToParty.name - assert ( - outputByPartyStr in ["SERVER", "CLIENT"] - ) + assert outputByPartyStr in ["SERVER", "CLIENT"] self.out.printf( "output({0}, {1});\n\n".format(outputByPartyStr, ir.expr.idf), indent=True ) diff --git a/Athos/SeeDot/Compiler.py b/Athos/SeeDot/Compiler.py index 5d1ff409..0bf44138 100644 --- a/Athos/SeeDot/Compiler.py +++ b/Athos/SeeDot/Compiler.py @@ -137,12 +137,15 @@ def fixOuputScale(self, res: (IR.Prog, IR.Expr), compiler: IRBuilderCSF): ), "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format( output_scale, Util.Config.consSF ) - i+=1 - final_cmd_list = prog.cmd_l[0:first_output_pos] + scaledown_cmd_list + prog.cmd_l[first_output_pos:] + i += 1 + final_cmd_list = ( + prog.cmd_l[0:first_output_pos] + + scaledown_cmd_list + + prog.cmd_l[first_output_pos:] + ) prog = IR.Prog(final_cmd_list) return (prog, expr) - def run(self): with open(Util.Config.astFile, "rb") as ff: ast = pickle.load(ff) @@ -160,7 +163,6 @@ def run(self): GC.run([mtdAST]) print("Garbage collection done.") - # Perform type inference and annotate nodes with type information InferType().visit(ast) diff --git a/Athos/SeeDot/IR/IR.py b/Athos/SeeDot/IR/IR.py index 7b3a84cb..09f320df 100644 --- a/Athos/SeeDot/IR/IR.py +++ b/Athos/SeeDot/IR/IR.py @@ -374,11 +374,9 @@ def subst(self, from_idf: str, to_e: Expr): self.inputByParty, ) + class Output(Cmd): - def __init__(self, - expr: Expr, - outputToParty: AST.Party - ): + def __init__(self, expr: Expr, outputToParty: AST.Party): self.expr = expr self.outputToParty = outputToParty @@ -388,6 +386,7 @@ def subst(self, from_idf: str, to_e: Expr): self.outputToParty, ) + class Decl(Cmd): def __init__( self, diff --git a/Athos/SeeDot/Type.py b/Athos/SeeDot/Type.py index 74473a4e..f6cb6d23 100644 --- a/Athos/SeeDot/Type.py +++ b/Athos/SeeDot/Type.py @@ -31,6 +31,7 @@ import copy import sys + class Type: pass diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 40a29104..7eda089e 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -98,10 +98,9 @@ def generateIRCode(graph, extraInfoDict): return (program, dictNodeNameToOutVarStr) - def addOutputs(program, dictNodeNameToOutVarStr, output_tensors): mtdAST = MtdAST() - assert (type(program) is AST.Let) + assert type(program) is AST.Let lastLetASTNode = program while True: if type(lastLetASTNode.expr) is AST.Let: @@ -117,7 +116,7 @@ def addOutputs(program, dictNodeNameToOutVarStr, output_tensors): else: outVarCt = 0 outVarPrefix = "O" - for i in range(0, len(output_tensors)): # name, decl, expr + for i in range(0, len(output_tensors)): # name, decl, expr t_name = output_tensors[i] if i == len(output_tensors) - 1: output_name = AST.ID(dictNodeNameToOutVarStr[t_name]) @@ -137,9 +136,6 @@ def addOutputs(program, dictNodeNameToOutVarStr, output_tensors): lastLetASTNode = newNode outVarCt += 1 - - print("Wooooo!") - #TODO return program @@ -330,7 +326,6 @@ def process_tf_graph(filename, output_tensors=None): (program, dictNodeNameToOutVarStr) = generateIRCode(graph, extraInfoDict) program = addOutputs(program, dictNodeNameToOutVarStr, output_tensors) - print("SeeDot AST generation done. Pickling the AST.") with open(os.path.join(folderName, "astOutput.pkl"), "wb") as f: pickle.dump(program, f) From 092db31a40642f4214caedf0dc8da888f5ae4d26 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 19 Feb 2021 18:53:41 +0530 Subject: [PATCH 61/72] add seal include path --- Athos/CompileTFGraph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 25fb4c5c..66bb7322 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -252,10 +252,11 @@ def generate_code(params, debug=False): sci_lib = os.path.join(sci, "build", "lib") eigen_path = os.path.join(sci, "extern", "eigen") seal_lib_path = os.path.join(sci, "extern", "SEAL", "native", "lib") + seal_inc_path = os.path.join(sci, "extern", "SEAL", "native", "src") if os.path.exists(sci_lib): os.system( """g++ {opt_flag} -fpermissive -pthread -w -maes -msse4.1 -mavx -mavx2 -mrdseed \ - -faligned-new -std=c++17 -fopenmp -I \"{eigen}\" -I \"{sci_src}\" \"{file}\" \ + -faligned-new -std=c++17 -fopenmp -I \"{eigen}\" -I \"{seal_inc_path}\" -I \"{sci_src}\" \"{file}\" \ -L \"{sci_lib}\" -lSCI-LinearHE -L \"{seal}\" -lseal -lssl -lcrypto \ -o \"{output}\"""".format( eigen=eigen_path, From 5c60161cae462b0a381d2795a5da669d1e58e2c3 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 19 Feb 2021 18:55:17 +0530 Subject: [PATCH 62/72] Add missing SCI library files --- Athos/TFEzPCLibrary/Library32_sci.ezpc | 126 +++++++++++++++++++++++++ Athos/TFEzPCLibrary/Library64_sci.ezpc | 126 +++++++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 Athos/TFEzPCLibrary/Library32_sci.ezpc create mode 100644 Athos/TFEzPCLibrary/Library64_sci.ezpc diff --git a/Athos/TFEzPCLibrary/Library32_sci.ezpc b/Athos/TFEzPCLibrary/Library32_sci.ezpc new file mode 100644 index 00000000..8d6197f7 --- /dev/null +++ b/Athos/TFEzPCLibrary/Library32_sci.ezpc @@ -0,0 +1,126 @@ +(* + +Authors: Nishant Kumar. + +Copyright: +Copyright (c) 2020 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +*) + +(**************************) +extern void MatMul2D(int32_pl i, int32_pl j, int32_pl k, int32_al[i][j] A, int32_al[j][k] B, int32_al[i][k] C, bool_pl modelIsA); + +(**************************) +extern void ArgMax(int32_pl s1, int32_pl s2, int32_al[s1][s2] inArr, int32_al[s1] outArr); + +(**************************) +extern void Relu(int32_pl s1, int32_al[s1] inArr, int32_al[s1] outArr, int32_pl sf, bool_pl doTruncation); + +(**************************) +extern void Floor(int32_pl s1, int32_al[s1] inArr, int32_al[s1] outArr, int32_pl sf); + +(**************************) +(* int32_al[N][H][W][C] input *) +extern void MaxPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, + int32_pl ksizeH, int32_pl ksizeW, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, + int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, + int32_al[N1][imgH][imgW][C1] inArr, + int32_al[N][H][W][C] outArr); + +(**************************) +(* int32_al[N][H][W][C] input *) +extern void AvgPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, + int32_pl ksizeH, int32_pl ksizeW, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, + int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, + int32_al[N1][imgH][imgW][C1] inArr, + int32_al[N][H][W][C] outArr); + +(**************************) +extern void ElemWiseSecretSharedVectorMult(int32_pl s1, int32_al[s1] arr1, int32_al[s1] arr2, int32_al[s1] outArr); +extern void ElemWiseActModelVectorMult(int32_pl s1, int32_al[s1] arr1, int32_al[s1] arr2, int32_al[s1] outArr); +extern void ElemWiseVectorPublicDiv(int32_pl s1, int32_al[s1] arr1, int32_pl divisor, int32_al[s1] outArr); + +(**************************) +extern void ScaleUp(int32_pl s1, int32_al[s1] arr, int32_pl sf); + +(**************************) +extern void ScaleDown(int32_pl s1, int32_al[s1] arr, int32_pl sf); + +(**************************) +extern void ClearMemSecret1(int32_pl s1, int32_al[s1] arr); +extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int32_al[s1][s2] arr); +extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int32_al[s1][s2][s3] arr); +extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_al[s1][s2][s3][s4] arr); +extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_al[s1][s2][s3][s4][s5] arr); +extern void ClearMemPublic(int32_pl x); +extern void ClearMemPublic1(int32_pl s, int32_pl[s] x); +extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr); +extern void ClearMemPublic3(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl[s1][s2][s3] arr); +extern void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl[s1][s2][s3][s4] arr); +extern void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr); + +(**************************) +extern void StartComputation(); +extern void EndComputation(); + +(**************************) +extern void Conv2DWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, + int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI][CO] filterArr, + int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); + +extern void Conv3DWrapper(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][D][H][W][CI] inputArr, + int32_al[FD][FH][FW][CI][CO] filterArr, + int32_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); + +extern void Conv2DGroupWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int32_al[N][H][W][CI] inputArr, + int32_al[FH][FW][CI/G][CO] filterArr, + int32_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); + +extern void ConvTranspose2DWrapper(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl H, int32_pl W, + int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideH, int32_pl strideW, + int32_al[N][HPrime][WPrime][CI] inputArr, + int32_al[FH][FW][CO][CI] filterArr, + int32_al[N][H][W][CO] outArr); + +extern void ConvTranspose3DWrapper(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int32_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int32_al[FD][FH][FW][CO][CI] filterArr, + int32_al[N][D][H][W][CO] outArr); diff --git a/Athos/TFEzPCLibrary/Library64_sci.ezpc b/Athos/TFEzPCLibrary/Library64_sci.ezpc new file mode 100644 index 00000000..5a1a6b35 --- /dev/null +++ b/Athos/TFEzPCLibrary/Library64_sci.ezpc @@ -0,0 +1,126 @@ +(* + +Authors: Nishant Kumar. + +Copyright: +Copyright (c) 2020 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +*) + +(**************************) +extern void MatMul2D(int32_pl i, int32_pl j, int32_pl k, int64_al[i][j] A, int64_al[j][k] B, int64_al[i][k] C, bool_pl modelIsA); + +(**************************) +extern void ArgMax(int32_pl s1, int32_pl s2, int64_al[s1][s2] inArr, int64_al[s1] outArr); + +(**************************) +extern void Relu(int32_pl s1, int64_al[s1] inArr, int64_al[s1] outArr, int32_pl sf, bool_pl doTruncation); + +(**************************) +extern void Floor(int32_pl s1, int64_al[s1] inArr, int64_al[s1] outArr, int32_pl sf); + +(**************************) +(* int64_al[N][H][W][C] input *) +extern void MaxPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, + int32_pl ksizeH, int32_pl ksizeW, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, + int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, + int64_al[N1][imgH][imgW][C1] inArr, + int64_al[N][H][W][C] outArr); + +(**************************) +(* int64_al[N][H][W][C] input *) +extern void AvgPool(int32_pl N, int32_pl H, int32_pl W, int32_pl C, + int32_pl ksizeH, int32_pl ksizeW, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, + int32_pl N1, int32_pl imgH, int32_pl imgW, int32_pl C1, + int64_al[N1][imgH][imgW][C1] inArr, + int64_al[N][H][W][C] outArr); + +(**************************) +extern void ElemWiseSecretSharedVectorMult(int32_pl s1, int64_al[s1] arr1, int64_al[s1] arr2, int64_al[s1] outArr); +extern void ElemWiseActModelVectorMult(int32_pl s1, int64_al[s1] arr1, int64_al[s1] arr2, int64_al[s1] outArr); +extern void ElemWiseVectorPublicDiv(int32_pl s1, int64_al[s1] arr1, int32_pl divisor, int64_al[s1] outArr); + +(**************************) +extern void ScaleUp(int32_pl s1, int64_al[s1] arr, int32_pl sf); + +(**************************) +extern void ScaleDown(int32_pl s1, int64_al[s1] arr, int32_pl sf); + +(**************************) +extern void ClearMemSecret1(int32_pl s1, int64_al[s1] arr); +extern void ClearMemSecret2(int32_pl s1, int32_pl s2, int64_al[s1][s2] arr); +extern void ClearMemSecret3(int32_pl s1, int32_pl s2, int32_pl s3, int64_al[s1][s2][s3] arr); +extern void ClearMemSecret4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int64_al[s1][s2][s3][s4] arr); +extern void ClearMemSecret5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int64_al[s1][s2][s3][s4][s5] arr); +extern void ClearMemPublic(int32_pl x); +extern void ClearMemPublic1(int32_pl s, int32_pl[s] x); +extern void ClearMemPublic2(int32_pl s1, int32_pl s2, int32_pl[s1][s2] arr); +extern void ClearMemPublic3(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl[s1][s2][s3] arr); +extern void ClearMemPublic4(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl[s1][s2][s3][s4] arr); +extern void ClearMemPublic5(int32_pl s1, int32_pl s2, int32_pl s3, int32_pl s4, int32_pl s5, int32_pl[s1][s2][s3][s4][s5] arr); + +(**************************) +extern void StartComputation(); +extern void EndComputation(); + +(**************************) +extern void Conv2DWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, + int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI][CO] filterArr, + int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); + +extern void Conv3DWrapper(int32_pl N, int32_pl D, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadDLeft, int32_pl zPadDRight, int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][D][H][W][CI] inputArr, + int64_al[FD][FH][FW][CI][CO] filterArr, + int64_al[N][((D-FD+(zPadDLeft+zPadDRight))/strideD)+1][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); + +extern void Conv2DGroupWrapper(int32_pl N, int32_pl H, int32_pl W, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl zPadHLeft, int32_pl zPadHRight, int32_pl zPadWLeft, int32_pl zPadWRight, + int32_pl strideH, int32_pl strideW, int32_pl G, + int64_al[N][H][W][CI] inputArr, + int64_al[FH][FW][CI/G][CO] filterArr, + int64_al[N][((H-FH+(zPadHLeft+zPadHRight))/strideH)+1][((W-FW+(zPadWLeft+zPadWRight))/strideW)+1][CO] outArr); + +extern void ConvTranspose2DWrapper(int32_pl N, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl H, int32_pl W, + int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideH, int32_pl strideW, + int64_al[N][HPrime][WPrime][CI] inputArr, + int64_al[FH][FW][CO][CI] filterArr, + int64_al[N][H][W][CO] outArr); + +extern void ConvTranspose3DWrapper(int32_pl N, int32_pl DPrime, int32_pl HPrime, int32_pl WPrime, int32_pl CI, + int32_pl FD, int32_pl FH, int32_pl FW, int32_pl CO, + int32_pl D, int32_pl H, int32_pl W, + int32_pl zPadTrDLeft, int32_pl zPadTrDRight, int32_pl zPadTrHLeft, int32_pl zPadTrHRight, int32_pl zPadTrWLeft, int32_pl zPadTrWRight, + int32_pl strideD, int32_pl strideH, int32_pl strideW, + int64_al[N][DPrime][HPrime][WPrime][CI] inputArr, + int64_al[FD][FH][FW][CO][CI] filterArr, + int64_al[N][D][H][W][CO] outArr); From bf45544b3f5ab144adc13f01bc4e5da487718b0c Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 19 Feb 2021 18:58:29 +0530 Subject: [PATCH 63/72] Bail out compilation if we encounter unsupported ops --- Athos/CompilerScripts/compile_tf.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py index 044e9fbb..1ada95cb 100644 --- a/Athos/CompilerScripts/compile_tf.py +++ b/Athos/CompilerScripts/compile_tf.py @@ -40,6 +40,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "TFCompiler")) import DumpTFMtData +from TFNodesAST import TFNodesAST def get_graph_from(graph_def): @@ -102,6 +103,15 @@ def infer_input_info(graph): return input_t_info +def get_unsupported_ops(graph): + ops = set([i.type for i in graph.get_operations()]) + unsupported_ops = [] + for op in ops: + if not hasattr(TFNodesAST, op): + unsupported_ops.append(op) + return unsupported_ops + + # Generates the computation graph and tensor size metadata and saves them in # the model directory. # Optionaly dumps model weights as fixedpt in specified scaling factor @@ -111,6 +121,15 @@ def compile(model_fname, input_t_info, output_t_names, scaling_factor, save_weig graph = tf_graph_io.load_pb(model_fname) assert tensors_exist(graph, output_t_names) + unsupported_ops = get_unsupported_ops(graph) + if len(unsupported_ops) != 0: + msg = ( + "Exiting compilation...\nCurrently we do not support the following ops: \n" + ) + for i in unsupported_ops: + msg = msg + " " + i + "\n" + sys.exit(msg) + if input_t_info == {}: input_t_info = infer_input_info(graph) else: From 3976119d28662ec4af5833895c1026f20cb03f0a Mon Sep 17 00:00:00 2001 From: Bhatu Date: Fri, 19 Feb 2021 19:46:23 +0530 Subject: [PATCH 64/72] typo --- Athos/CompileTFGraph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 66bb7322..3df3a47a 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -263,6 +263,7 @@ def generate_code(params, debug=False): sci_src=sci_src, file=output_file, sci_lib=sci_lib, + seal_inc_path=seal_inc_path, seal=seal_lib_path, output=program_path, opt_flag=opt_flag, From f7e15792ab9866e9deb38429141cf098331353b1 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Sat, 20 Feb 2021 02:36:07 +0530 Subject: [PATCH 65/72] Script to convert numpy arr to fixedpt --- .../CompilerScripts/convert_np_to_fixedpt.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 Athos/CompilerScripts/convert_np_to_fixedpt.py diff --git a/Athos/CompilerScripts/convert_np_to_fixedpt.py b/Athos/CompilerScripts/convert_np_to_fixedpt.py new file mode 100644 index 00000000..50e22149 --- /dev/null +++ b/Athos/CompilerScripts/convert_np_to_fixedpt.py @@ -0,0 +1,43 @@ +import argparse +from argparse import RawTextHelpFormatter + +import parse_config +import os +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--inp", + required=True, + type=str, + help="Path to numpy array dumped using np.save (.npy file)", + ) + parser.add_argument( + "--config", required=True, type=str, help="Path to the config json file" + ) + args = parser.parse_args() + return args + + +def convert_np_to_fixedpt(path_to_numpy_arr, scaling_factor): + if not os.path.exists(path_to_numpy_arr): + sys.exit("Numpy arr {} specified does not exist".format(path_to_numpy_arr)) + input_name = os.path.splitext(path_to_numpy_arr)[0] + output_path = input_name + "_fixedpt_scale_" + str(scaling_factor) + ".inp" + + np_inp = np.load(path_to_numpy_arr, allow_pickle=True) + with open(output_path, "w") as ff: + for xx in np.nditer(np_inp, order="C"): + ff.write(str(int(xx * (1 << scaling_factor))) + " ") + ff.write("\n") + return output_path + + +if __name__ == "__main__": + args = parse_args() + params = parse_config.get_params(args.config) + scale = 12 if params["scale"] is None else params["scale"] + output_path = convert_np_to_fixedpt(args.inp, scale) + print("Fixed point output saved in ", output_path) From cacf10f31cddf55e4a06908fcfc64f8d7d0f85bd Mon Sep 17 00:00:00 2001 From: Bhatu Date: Sat, 20 Feb 2021 02:42:08 +0530 Subject: [PATCH 66/72] Script improvements. Split server/client --- Athos/CompileTFGraph.py | 60 ++++++++++++++++++++++----- Athos/CompilerScripts/compile_tf.py | 9 ++++ Athos/CompilerScripts/get_output.py | 11 +++-- Athos/CompilerScripts/parse_config.py | 2 - Athos/CompilerScripts/tf_graph_io.py | 6 +++ Athos/TFCompiler/ProcessTFGraph.py | 2 +- Athos/tests/utils.py | 2 +- 7 files changed, 74 insertions(+), 18 deletions(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 3df3a47a..684e6525 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -28,6 +28,7 @@ import os.path import json import sys +from zipfile import ZipFile import TFCompiler.ProcessTFGraph as Athos import CompilerScripts.parse_config as parse_config @@ -36,6 +37,16 @@ def parse_args(): parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--role", + required=True, + type=str, + choices=["server", "client"], + help=""" +Choose server if you are the model owner. +Choose client if you are the data owner. +""", + ) parser.add_argument( "--config", required=True, @@ -56,7 +67,7 @@ def parse_args(): //--------------------------- Optional options --------------------------- "scale":10, // Scaling factor to compile for. DEFAULT=12. - "bitlength":64, // Bit length to compile for. DEFAULT=12. + "bitlength":64, // Bit length to compile for. DEFAULT=64. "save_weights" : true, // Save model scaled weights in fixed point. DEFAULT=true. "input_tensors":{ // Name and shape of the input tensors @@ -82,13 +93,19 @@ def parse_args(): return args -def generate_code(params, debug=False): +def generate_code(params, role, debug=False): model_name = params["model_name"] input_tensor_info = params["input_tensors"] output_tensors = params["output_tensors"] scale = 12 if params["scale"] is None else params["scale"] - bitlength = 64 if params["bitlength"] is None else params["bitlength"] target = params["target"] + if params["bitlength"] is None: + if target == "SCI": + bitlength = 63 + else: + bitlength = 64 + else: + bitlength = params["bitlength"] save_weights = True if params["save_weights"] is None else params["save_weights"] disable_all_hlil_opts = ( False @@ -123,13 +140,29 @@ def generate_code(params, debug=False): athos_dir = os.path.dirname(os.path.abspath(__file__)) model_abs_path = os.path.abspath(model_name) model_abs_dir = os.path.dirname(model_abs_path) - # Generate graphdef and sizeInfo metadata - weights_path = compile_tf.compile( - model_name, input_tensor_info, output_tensors, scale, save_weights - ) + + pruned_model_path = os.path.join(model_abs_dir, "optimised_" + model_name) + if role == "server": + # Generate graphdef and sizeInfo metadata + weights_path = compile_tf.compile( + model_name, input_tensor_info, output_tensors, scale, save_weights + ) + # Zip the pruned model, sizeInfo to send to client + file_list = [ + pruned_model_path, + os.path.join(model_abs_dir, "sizeInfo.mtdata"), + ] + if "config_name" in params: + file_list.append(params["config_name"]) + zip_path = os.path.join(model_abs_dir, "client.zip") + with ZipFile(zip_path, "w") as zip: + for file in file_list: + zip.write(file, os.path.basename(file)) + elif role == "client": + compile_tf.save_graph_def(pruned_model_path) # Compile to seedot. Generate AST in model directory - Athos.process_tf_graph(model_abs_path, output_tensors) + Athos.process_tf_graph(model_abs_dir, output_tensors) # Compile to ezpc model_base_name = os.path.basename(model_abs_path)[:-3] @@ -275,12 +308,17 @@ def generate_code(params, debug=False): ) os.chdir(cwd) - print("Generated binary: {}".format(program_path)) - print("Use as input to server (model weights): {}".format(weights_path)) + print("\n\nGenerated binary: {}".format(program_path)) + if role == "server": + print("\n\nUse as input to server (model weights): {}".format(weights_path)) + print("Share {} file with the client".format(zip_path)) + if role == "client": + weights_path = "" return (program_path, weights_path) if __name__ == "__main__": args = parse_args() params = parse_config.get_params(args.config) - generate_code(params) + params["config_name"] = args.config + generate_code(params, args.role) diff --git a/Athos/CompilerScripts/compile_tf.py b/Athos/CompilerScripts/compile_tf.py index 1ada95cb..90cc60ed 100644 --- a/Athos/CompilerScripts/compile_tf.py +++ b/Athos/CompilerScripts/compile_tf.py @@ -30,6 +30,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + import tensorflow as tf import numpy as np @@ -43,6 +44,14 @@ from TFNodesAST import TFNodesAST +def save_graph_def(path_to_pb): + if not os.path.exists(path_to_pb): + sys.exit("Cannot find " + path_to_pb) + gd = tf_graph_io.load_graph_def_pb(path_to_pb) + DumpTFMtData.save_graphdef(gd) + return + + def get_graph_from(graph_def): with tf.Graph().as_default() as graph: tf.import_graph_def(graph_def, name="") diff --git a/Athos/CompilerScripts/get_output.py b/Athos/CompilerScripts/get_output.py index 76a328d5..aa49aac2 100644 --- a/Athos/CompilerScripts/get_output.py +++ b/Athos/CompilerScripts/get_output.py @@ -28,9 +28,14 @@ def convert_raw_output_to_np(filename, bitlength, scale): output_fname = sys.argv[1] config_name = sys.argv[2] params = parse_config.get_params(config_name) - scale = params["scale"] - bitlength = params["bitlength"] - bitlength = 64 if bitlength is None else bitlength + scale = 12 if params["scale"] is None else params["scale"] + if params["bitlength"] is None: + if target == "SCI": + bitlength = 63 + else: + bitlength = 64 + else: + bitlength = params["bitlength"] model_name = os.path.splitext(params["model_name"])[0] np_arr = convert_raw_output_to_np(output_fname, bitlength, scale) np.save(model_name + "_output", np_arr) diff --git a/Athos/CompilerScripts/parse_config.py b/Athos/CompilerScripts/parse_config.py index d55c85eb..2bbc2c05 100644 --- a/Athos/CompilerScripts/parse_config.py +++ b/Athos/CompilerScripts/parse_config.py @@ -174,8 +174,6 @@ def parse_config(config, sample_network=False): + " is not a tensorflow protobuf file. Please supply " + "a valid tensorflow protobuf model (.pb extension)" ) - if not os.path.isfile(model_fname): - sys.exit(model_fname + " file does not exist") else: network_name = get_str_param(config, "network_name") run_in_tmux = get_opt_bool_param(config, "run_in_tmux") diff --git a/Athos/CompilerScripts/tf_graph_io.py b/Athos/CompilerScripts/tf_graph_io.py index 82c90c68..b78e5716 100644 --- a/Athos/CompilerScripts/tf_graph_io.py +++ b/Athos/CompilerScripts/tf_graph_io.py @@ -23,6 +23,8 @@ """ import tensorflow as tf from tensorflow.python.platform import gfile +import os.path +import sys def display_graph(graph, tensorboard_log_dir): @@ -31,6 +33,8 @@ def display_graph(graph, tensorboard_log_dir): def load_graph_def_pb(path_to_pb): + if not os.path.isfile(path_to_pb): + sys.exit(path_to_pb + " file does not exist") with tf.io.gfile.GFile(path_to_pb, "rb") as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) @@ -38,6 +42,8 @@ def load_graph_def_pb(path_to_pb): def load_pb(path_to_pb): + if not os.path.isfile(path_to_pb): + sys.exit(path_to_pb + " file does not exist") graph_def = load_graph_def_pb(path_to_pb) with tf.Graph().as_default() as graph: tf.import_graph_def(graph_def, name="") diff --git a/Athos/TFCompiler/ProcessTFGraph.py b/Athos/TFCompiler/ProcessTFGraph.py index 7eda089e..0c4b7461 100644 --- a/Athos/TFCompiler/ProcessTFGraph.py +++ b/Athos/TFCompiler/ProcessTFGraph.py @@ -283,7 +283,7 @@ def process_tf_graph(filename, output_tensors=None): if os.path.isfile(filename): folderName = os.path.dirname(filename) - elif os.path.isdir(filename): + else: folderName = filename graphFileName = os.path.join(folderName, "graphDef.mtdata") graph = Graph.Graph() diff --git a/Athos/tests/utils.py b/Athos/tests/utils.py index 9b05da09..bc975621 100644 --- a/Athos/tests/utils.py +++ b/Athos/tests/utils.py @@ -202,7 +202,7 @@ def compile_and_run(self, inputs, timeoutSeconds=40): params = get_params(self.config) print(params) (output_program, model_weight_file) = CompileTFGraph.generate_code( - params, debug=False + params, role="server", debug=False ) prog = Program(output_program, model_weight_file, params, self.test_dir) output = prog.run(inputs, timeoutSeconds) From cb9496acc61ee55141c4120d26c3f5dbbeaadec5 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 24 Feb 2021 14:04:00 +0530 Subject: [PATCH 67/72] Updated scripts --- .../comparison_scripts/compare_np_arrs.py | 45 +++++++++++++++++++ Athos/CompilerScripts/get_output.py | 2 +- 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py diff --git a/Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py b/Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py new file mode 100644 index 00000000..6609f7a8 --- /dev/null +++ b/Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py @@ -0,0 +1,45 @@ +""" + +Authors: Pratik Bhatu. + +Copyright: +Copyright (c) 2021 Microsoft Research +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +""" +import numpy as np +import sys + + +if __name__ == "__main__": + if len(sys.argv) != 3: + sys.exit("Usage: compare_np_arrs.py arr1.npy arr2.npy") + arr1 = np.load(sys.argv[1], allow_pickle=True).flatten() + arr2 = np.load(sys.argv[2], allow_pickle=True).flatten() + + matching_prec = -1 + for prec in range(1, 10): + try: + np.testing.assert_almost_equal(arr1, arr2, decimal=prec) + except AssertionError: + break + matching_prec = prec + + if matching_prec == -1: + print("Output mismatch") + else: + print("Arrays matched upto {} decimal points".format(matching_prec)) diff --git a/Athos/CompilerScripts/get_output.py b/Athos/CompilerScripts/get_output.py index aa49aac2..7dad51b3 100644 --- a/Athos/CompilerScripts/get_output.py +++ b/Athos/CompilerScripts/get_output.py @@ -30,7 +30,7 @@ def convert_raw_output_to_np(filename, bitlength, scale): params = parse_config.get_params(config_name) scale = 12 if params["scale"] is None else params["scale"] if params["bitlength"] is None: - if target == "SCI": + if params["target"] == "SCI": bitlength = 63 else: bitlength = 64 From c053ea2e0e26f28a986bdf88c71446c2057fdb5b Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 24 Feb 2021 16:58:29 +0530 Subject: [PATCH 68/72] Add quick mode to setup script --- README.md | 2 +- setup_env_and_build.sh | 47 +++++++++++++++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 8 deletions(-) mode change 100644 => 100755 setup_env_and_build.sh diff --git a/README.md b/README.md index 92dff75f..e8609298 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ With these components in place, we are able to run for the first time secure inf ## Setup For setup instructions, please refer to each of the components' readme. -Alternatively you can use the **setup_env_and_build.sh** script. It installs dependencies and builds each component. It also creates a virtual environment in a *mpc_venv* folder with all the required packages. +Alternatively you can use the **setup_env_and_build.sh** script. It installs dependencies and builds each component. It also creates a virtual environment in a *mpc_venv* folder with all the required packages.If you want to do setup with default paths and settings do ``./setup_env_and_build.sh quick``, otherwise you want to manually choose paths you can use ``./setup_env_and_build.sh``. Please do ``source mpc_venv/bin/activate`` before using the toolchain. diff --git a/setup_env_and_build.sh b/setup_env_and_build.sh old mode 100644 new mode 100755 index 3cd72eac..79ca125b --- a/setup_env_and_build.sh +++ b/setup_env_and_build.sh @@ -20,19 +20,52 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -sudo add-apt-repository ppa:deadsnakes/ppa -sudo add-apt-repository ppa:avsm/ppa +mode=$1 + +sudo add-apt-repository ppa:deadsnakes/ppa -y +sudo add-apt-repository ppa:avsm/ppa -y sudo apt update sudo apt install -y build-essential make cmake libgmp-dev libglib2.0-dev libssl-dev libboost-all-dev m4 python3.7 opam +sudo apt install -y unzip bubblewrap + +wget "https://raw.githubusercontent.com/ocaml/opam/master/shell/install.sh" +if [ $? -ne 0 ]; then + echo "Downloading of opam script failed"; exit +fi +chmod +x install.sh +if [[ "$mode" == "quick" ]]; then + yes "" | ./install.sh +else + ./install.sh +fi +if [ $? -ne 0 ]; then + rm install.sh + echo "Opam installation failed"; exit +fi +rm install.sh -sudo apt install unzip bubblewrap -sh <(curl -sL https://raw.githubusercontent.com/ocaml/opam/master/shell/install.sh) # environment setup -opam init -eval `opam env` +if [[ "$mode" == "quick" ]]; then + yes "" | opam init +else + opam init +fi +if [ $? -ne 0 ]; then + echo "opam init failed"; exit +fi + # install given version of the compiler -opam switch create 4.10.0 +eval `opam env` +if [[ "$mode" == "quick" ]]; then + yes "" | opam switch create 4.10.0 +else + opam switch create 4.10.0 +fi +opam switch list | grep "4.10.0" 2>/dev/null +if [ $? -ne 0 ]; then + echo "opam switch create 4.10.0 failed"; exit +fi eval `opam env` # check if we got what we wanted which ocaml From 2b006d719db733881c1bfdcbf61c8c802f2b9e56 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 24 Feb 2021 17:07:08 +0530 Subject: [PATCH 69/72] Fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e8609298..fd880721 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ With these components in place, we are able to run for the first time secure inf ## Setup For setup instructions, please refer to each of the components' readme. -Alternatively you can use the **setup_env_and_build.sh** script. It installs dependencies and builds each component. It also creates a virtual environment in a *mpc_venv* folder with all the required packages.If you want to do setup with default paths and settings do ``./setup_env_and_build.sh quick``, otherwise you want to manually choose paths you can use ``./setup_env_and_build.sh``. +Alternatively you can use the **setup_env_and_build.sh** script. It installs dependencies and builds each component. It also creates a virtual environment in a *mpc_venv* folder with all the required packages. If you want to do setup with default paths and settings do ``./setup_env_and_build.sh quick``, otherwise if you want to manually choose paths you can use ``./setup_env_and_build.sh``. Please do ``source mpc_venv/bin/activate`` before using the toolchain. From 01488724c219a070db4131b6d89befad25fc5a68 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 24 Feb 2021 17:34:58 +0530 Subject: [PATCH 70/72] Add verbose print option while comparing np arrays --- .../comparison_scripts/compare_np_arrs.py | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py b/Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py index 6609f7a8..8f39d741 100644 --- a/Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py +++ b/Athos/CompilerScripts/comparison_scripts/compare_np_arrs.py @@ -23,13 +23,34 @@ """ import numpy as np import sys +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--inputs", + help="Paths of the two numpy arrays to compare. -i arr1.npy arr2.npy ", + required=True, + type=str, + nargs=2, + ) + parser.add_argument( + "-v", + "--verbose", + help="Verbose mode. Prints arrays.", + action="store_true", + ) + args = parser.parse_args() + return args if __name__ == "__main__": - if len(sys.argv) != 3: - sys.exit("Usage: compare_np_arrs.py arr1.npy arr2.npy") - arr1 = np.load(sys.argv[1], allow_pickle=True).flatten() - arr2 = np.load(sys.argv[2], allow_pickle=True).flatten() + args = parse_args() + + arr1 = np.load(args.inputs[0], allow_pickle=True).flatten() + arr2 = np.load(args.inputs[1], allow_pickle=True).flatten() matching_prec = -1 for prec in range(1, 10): @@ -43,3 +64,7 @@ print("Output mismatch") else: print("Arrays matched upto {} decimal points".format(matching_prec)) + + if args.verbose: + print(args.inputs[0], ": ", arr1) + print(args.inputs[1], ": ", arr2) From 3f72d1519529279a47d9c2bc01799d7e65db07e1 Mon Sep 17 00:00:00 2001 From: Bhatu Date: Thu, 25 Feb 2021 13:32:28 +0530 Subject: [PATCH 71/72] Fix parsing of tensor contents from graphdef --- Athos/TFCompiler/Graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Athos/TFCompiler/Graph.py b/Athos/TFCompiler/Graph.py index 194b6b65..c8bdaef3 100644 --- a/Athos/TFCompiler/Graph.py +++ b/Athos/TFCompiler/Graph.py @@ -322,7 +322,7 @@ def readFromFilePointer(self, fileP, cnt): elif curToken == "tensor_content:": if errIfTokensNotMinLen(tokens, 2, cnt, "Tensor"): return (False, cnt) - self.__tensorContentInput = tokens[1] + self.__tensorContentInput = " ".join(tokens[1:]) self.__convToBytes() elif curToken == "float_val:": if errIfTokensNotMinLen(tokens, 2, cnt, "Tensor"): From a86f7fbb47903ff91744f6aaf8fc674559bcc1eb Mon Sep 17 00:00:00 2001 From: Bhatu Date: Wed, 10 Mar 2021 17:55:33 +0530 Subject: [PATCH 72/72] [SCI] Fix bug in maxpool Fix padding to 8 for multithreading. --- Athos/CompileTFGraph.py | 9 +++++---- Athos/tests/tf/unittests/test_unaryops.py | 2 -- SCI/src/functionalities_wrapper.h | 6 +++++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/Athos/CompileTFGraph.py b/Athos/CompileTFGraph.py index 684e6525..860d937a 100644 --- a/Athos/CompileTFGraph.py +++ b/Athos/CompileTFGraph.py @@ -94,7 +94,7 @@ def parse_args(): def generate_code(params, role, debug=False): - model_name = params["model_name"] + model_path = params["model_name"] input_tensor_info = params["input_tensors"] output_tensors = params["output_tensors"] scale = 12 if params["scale"] is None else params["scale"] @@ -138,14 +138,15 @@ def generate_code(params, role, debug=False): cwd = os.getcwd() athos_dir = os.path.dirname(os.path.abspath(__file__)) - model_abs_path = os.path.abspath(model_name) + model_name = os.path.basename(model_path) + model_abs_path = os.path.abspath(model_path) model_abs_dir = os.path.dirname(model_abs_path) pruned_model_path = os.path.join(model_abs_dir, "optimised_" + model_name) if role == "server": # Generate graphdef and sizeInfo metadata weights_path = compile_tf.compile( - model_name, input_tensor_info, output_tensors, scale, save_weights + model_path, input_tensor_info, output_tensors, scale, save_weights ) # Zip the pruned model, sizeInfo to send to client file_list = [ @@ -165,7 +166,7 @@ def generate_code(params, role, debug=False): Athos.process_tf_graph(model_abs_dir, output_tensors) # Compile to ezpc - model_base_name = os.path.basename(model_abs_path)[:-3] + model_base_name = model_name[:-3] ezpc_file_name = "{mname}_{bl}_{target}.ezpc".format( mname=model_base_name, bl=bitlength, target=target.lower() ) diff --git a/Athos/tests/tf/unittests/test_unaryops.py b/Athos/tests/tf/unittests/test_unaryops.py index 6477de6b..026ec145 100644 --- a/Athos/tests/tf/unittests/test_unaryops.py +++ b/Athos/tests/tf/unittests/test_unaryops.py @@ -149,8 +149,6 @@ def test_argmax(test_dir, backend, a_shape, axis, dtype): def test_pool( test_dir, backend, tfOp, a_shape, ksize, strides, padding, data_format, dtype ): - if backend.startswith("2PC") and tfOp == tf.nn.max_pool: - pytest.skip("[SCI][maxpool] Output mismatch bug") graph = tf.Graph() a_inp = dtype(np.random.randn(*a_shape)) with graph.as_default(): diff --git a/SCI/src/functionalities_wrapper.h b/SCI/src/functionalities_wrapper.h index 983ae3d0..5cb86213 100644 --- a/SCI/src/functionalities_wrapper.h +++ b/SCI/src/functionalities_wrapper.h @@ -30,8 +30,10 @@ SOFTWARE. #include #include + #ifdef VERIFY_LAYERWISE #include "functionalities_pt.h" + #endif void Conv2D(int32_t N, int32_t H, int32_t W, int32_t CI, @@ -524,7 +526,9 @@ void MaxPool(int32_t N, int32_t H, int32_t W, int32_t C, } for(int i=rowsOrig;i