export_model.py
import argparse

import tensorflow as tf

from model import DeeplabV3Plus, UNET


def generate_saved_model(model_weights, export_path):
    """Load DeeplabV3Plus weights from a checkpoint and export a SavedModel."""
    model = DeeplabV3Plus(512)
    print("LOADING THE MODEL")
    model.load_weights(model_weights)
    print("EXPORTING THE MODEL")
    model.export(export_path)
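
# Illustrative check (an assumption, not part of the original script): the SavedModel
# written by model.export() can be loaded back through its serving signature to verify
# the export. The path below is a hypothetical example.
#
# loaded = tf.saved_model.load("exported_model/")
# infer = loaded.signatures["serving_default"]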


def quantize_model(saved_model_path, dataset_generator):
    """Post-training quantization of an exported SavedModel with TFLite.

    Returns the quantized model as .tflite flatbuffer bytes.
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = dataset_generator
    return converter.convert()
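
# Example sketch (assumptions, not part of the original script): a representative
# dataset generator of the shape quantize_model expects, and writing the returned
# .tflite bytes to disk. The glob pattern, image size, and output filename are
# hypothetical; adjust them to the data the model was trained on.
#
# import glob
#
# def representative_dataset():
#     for path in glob.glob("calibration_images/*.png")[:100]:
#         image = tf.io.decode_png(tf.io.read_file(path), channels=3)
#         image = tf.image.resize(image, (512, 512)) / 255.0
#         yield [tf.expand_dims(image, axis=0)]
#
# tflite_bytes = quantize_model("exported_model/", representative_dataset)
# with open("model_quantized.tflite", "wb") as f:
#     f.write(tflite_bytes)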


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Export model',
        description='Export a tf checkpoint to SavedModel and/or quantize the model')
    parser.add_argument('model_weights')
    parser.add_argument('export_path')
    args = parser.parse_args()
    generate_saved_model(args.model_weights, args.export_path)