llvm-project/llvm/lib/Analysis/models/saved-model-to-tflite.py
Tobias Hieta b71edfaa4e
[NFC][Py Reformat] Reformat python files in llvm
This is the first commit in a series that will reformat
all the python files in the LLVM repository.

Reformatting is done with `black`.

See more information here:

https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style

Reviewed By: jhenderson, JDevlieghere, MatzeB

Differential Revision: https://reviews.llvm.org/D150545
2023-05-17 10:48:52 +02:00

38 lines
1.0 KiB
Python

"""Convert a saved model to tflite model.
Usage: python3 saved-model-to-tflite.py <mlgo saved_model_dir> <tflite dest_dir>
The <tflite dest_dir> will contain:
model.tflite: this is the converted saved model
output_spec.json: the output spec, copied from the saved_model dir (only if it exists there).
"""
import tensorflow as tf
import os
import sys
from tf_agents.policies import greedy_policy
def main(argv):
    """Convert a TensorFlow saved model into a TFLite model on disk.

    Writes ``model.tflite`` into the destination directory (creating the
    directory if needed) and, when the saved model directory contains an
    ``output_spec.json``, copies that file alongside it.

    Args:
        argv: Command-line style argument list holding exactly three items:
            the program name, the saved model directory, and the destination
            directory for the TFLite artifacts.

    Raises:
        SystemExit: If ``argv`` does not hold exactly three items; the exit
            payload is a usage message.
    """
    # Validate explicitly rather than with `assert`: assertions are stripped
    # under `python -O`, which would turn a bad command line into an
    # IndexError (too few args) or silently ignore extra arguments.
    if len(argv) != 3:
        sys.exit(
            "Usage: python3 saved-model-to-tflite.py "
            "<mlgo saved_model_dir> <tflite dest_dir>"
        )
    sm_dir, tfl_dir = argv[1], argv[2]

    tf.io.gfile.makedirs(tfl_dir)
    tfl_path = os.path.join(tfl_dir, "model.tflite")

    converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
    # Only builtin TFLite ops are allowed in the converted model.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
    ]
    tfl_model = converter.convert()
    with tf.io.gfile.GFile(tfl_path, "wb") as f:
        f.write(tfl_model)

    # The output spec is optional in the source directory; copy it only when
    # present.
    json_file = "output_spec.json"
    src_json = os.path.join(sm_dir, json_file)
    if tf.io.gfile.exists(src_json):
        # overwrite=True keeps re-runs into an existing destination working,
        # consistent with model.tflite being rewritten above.
        tf.io.gfile.copy(
            src_json, os.path.join(tfl_dir, json_file), overwrite=True
        )
# Script entry point: when run directly (not imported), hand the raw command
# line, program name included, to main().
if __name__ == "__main__":
    main(sys.argv)