-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathQATexport.py
35 lines (33 loc) · 1.05 KB
/
QATexport.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import time
import torch
from pathlib import Path
from timm.models import create_model
import levit
import levit_c
def AutoExport(check_add, output_path='QAT/quant_LeVit-QAT.onnx'):
    """Export a QAT-calibrated LeViT_128S checkpoint to ONNX.

    Parameters
    ----------
    check_add : str or Path
        Path to the calibrated state-dict checkpoint (.pth) to load.
    output_path : str or Path, optional
        Destination of the exported ONNX file. Defaults to the original
        hard-coded location for backward compatibility.
    """
    # Lazy import: pytorch_quantization is only needed for this export path.
    from pytorch_quantization import nn as quant_nn
    from pytorch_quantization import quant_modules

    # Make TensorQuantizer emit fake-quant ops that ONNX export understands
    # (QuantizeLinear/DequantizeLinear), then monkey-patch torch modules so
    # create_model builds a quantization-aware graph.
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    quant_modules.initialize()

    model = create_model(
        'LeViT_128S',
        num_classes=1000,
        distillation=False,
        pretrained=False,
        fuse=False,
    )
    checkpoint = torch.load(check_add, map_location='cpu')
    model.load_state_dict(checkpoint)
    model.cuda()
    # Export in inference mode; otherwise BN/dropout are traced in
    # training behavior, which is wrong for a deployment graph.
    model.eval()

    dummy_input = torch.randn(1, 3, 224, 224, device='cuda')

    # Create the destination directory up front so torch.onnx.export does
    # not fail with an opaque I/O error when it is missing.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    torch.onnx.export(model, dummy_input, str(output_path),
                      verbose=False, opset_version=13,
                      input_names=["input_0"],
                      output_names=["output_0"],
                      # Constant folding is disabled so the fake-quant
                      # nodes survive into the exported graph.
                      do_constant_folding=False,
                      )
if __name__ == '__main__':
    # Minimal CLI; the default reproduces the original hard-coded call,
    # so running the script with no arguments is unchanged behavior.
    parser = argparse.ArgumentParser(
        description='Export a QAT-calibrated LeViT checkpoint to ONNX.')
    parser.add_argument('--checkpoint',
                        default='QAT/quant_LeVit-calibrated45.pth',
                        help='path to the calibrated QAT state-dict (.pth)')
    args = parser.parse_args()
    AutoExport(args.checkpoint)