| import os |
| import json |
| import shutil |
|
|
| from optimum.exporters.onnx import main_export |
| import onnx |
| from onnxconverter_common import float16 |
| import onnxruntime as rt |
| from onnxruntime.tools.onnx_model_utils import * |
| from onnxruntime.quantization import quantize_dynamic, QuantType |
| import huggingface_hub |
|
|
def add_mean_pooling(input_model, output_model, op, IR, output_embeddings_number):
    """Append attention-masked mean pooling and L2 normalization to an ONNX model.

    The graph loaded from ``input_model`` must have an ``attention_mask`` input
    and a ``last_hidden_state`` output. A ``sentence_embedding`` output is added:
    ``last_hidden_state`` is averaged over the positions selected by the
    attention mask (axis 1), then L2-normalized with the ``ai.onnx.ml``
    ``Normalizer`` op. ``last_hidden_state`` is removed from the graph outputs.
    ``input_model`` and ``output_model`` may be the same path, because the input
    is fully loaded before the output is written.

    Args:
        input_model: Path of the ONNX model to read.
        output_model: Path the augmented model is written to, as a single file
            with no external data.
        op: ``onnx.OperatorSetIdProto`` for the default ONNX domain. The added
            Unsqueeze/ReduceSum/Squeeze nodes pass their axes as inputs, which
            needs opset 13 or newer.
        IR: ONNX IR version stored in the saved model.
        output_embeddings_number: Embedding width. It is used as the static
            second dimension of the ``sentence_embedding`` output.
    """
    make_node = onnx.helper.make_node
    make_tensor = onnx.helper.make_tensor

    model = onnx.load(input_model)
    graph = model.graph

    graph.initializer.extend([
        make_tensor("minus_one_axis", onnx.TensorProto.INT64, [1], [-1]),
        # Clip's `min` input must be a scalar (empty shape) per the ONNX spec.
        make_tensor("mask_clip_lower_limit", onnx.TensorProto.FLOAT, [], [1e-9]),
        make_tensor("sum_one_axis", onnx.TensorProto.INT64, [1], [1]),
    ])

    # NOTE(review): assumes last_hidden_state is (batch, seq_len, hidden) and
    # attention_mask is (batch, seq_len). The declared output shape below relies
    # on this. Confirm for non-BERT-style exports.
    graph.node.extend([
        make_node("Cast", ["attention_mask"], ["attention_mask_fp32"],
                  to=onnx.TensorProto.FLOAT),
        # Add a trailing axis, then broadcast the mask to last_hidden_state's shape.
        make_node("Unsqueeze", ["attention_mask_fp32", "minus_one_axis"],
                  ["unsqueezed_attention_mask"]),
        make_node("Shape", ["last_hidden_state"], ["last_hidden_state_shape"]),
        make_node("Expand", ["unsqueezed_attention_mask", "last_hidden_state_shape"],
                  ["expanded_attention_mask"]),
        # Zero out padded positions, then sum the embeddings and the mask over axis 1.
        make_node("Mul", ["last_hidden_state", "expanded_attention_mask"],
                  ["last_hidden_state_x_expanded_attention_mask"]),
        make_node("ReduceSum", ["last_hidden_state_x_expanded_attention_mask", "sum_one_axis"],
                  ["sum_last_hidden_state_x_expanded_attention_mask"]),
        make_node("ReduceSum", ["expanded_attention_mask", "sum_one_axis"],
                  ["sum_expanded_attention_mask"]),
        # Lower-bound the token count so an all-zero mask cannot divide by zero.
        make_node("Clip", ["sum_expanded_attention_mask", "mask_clip_lower_limit"],
                  ["clipped_sum_expanded_attention_mask"]),
        make_node("Div", ["sum_last_hidden_state_x_expanded_attention_mask",
                          "clipped_sum_expanded_attention_mask"],
                  ["pooled_embeddings"]),
        # ReduceSum keeps the reduced axis by default (keepdims=1), so squeeze it away.
        make_node("Squeeze", ["pooled_embeddings", "sum_one_axis"],
                  ["squeezed_pooled_embeddings"]),
        make_node("Normalizer", ["squeezed_pooled_embeddings"], ["sentence_embedding"],
                  domain="ai.onnx.ml", norm="L2"),
    ])

    # Iterate over a snapshot, because removing from a protobuf repeated field
    # while iterating it can skip elements.
    for value_info in list(graph.output):
        if value_info.name == "last_hidden_state":
            graph.output.remove(value_info)

    graph.output.append(
        onnx.helper.make_tensor_value_info(
            "sentence_embedding",
            onnx.TensorProto.FLOAT,
            shape=["batch_size", output_embeddings_number],
        )
    )

    # Keep the caller's default-domain opset and carry over any other domains
    # the source model already declared. Also make sure ai.onnx.ml is declared,
    # because the Normalizer node above lives in that domain.
    opset_imports = [op]
    opset_imports.extend(
        imp for imp in model.opset_import if imp.domain not in ("", "ai.onnx")
    )
    if not any(imp.domain == "ai.onnx.ml" for imp in opset_imports):
        opset_imports.append(onnx.helper.make_opsetid("ai.onnx.ml", 1))

    # Honour the requested IR version (previously overridden with a hard-coded 8).
    pooled_model = onnx.helper.make_model(
        graph, ir_version=IR, opset_imports=opset_imports
    )
    onnx.save(pooled_model, output_model, save_as_external_data=False)
|
|
| |
|
|
# Read the conversion settings. Every key below is required; a missing key
# raises KeyError.
with open("conversion_config.json", encoding="utf-8") as json_file:
    conversion_config = json.load(json_file)

model_id = conversion_config["model_id"]
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
# Maps a precision name (this script handles "fp32", "int8" and "uint8") to the
# path of the ONNX file to produce for it.
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]
|
|
| |
# Opset declaration for the default ONNX domain (the domain field is left empty).
op = onnx.OperatorSetIdProto()
op.version = opset

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs("onnx", exist_ok=True)
|
|
print("Exporting the main model version")
try:
    # SECURITY NOTE: trust_remote_code=True runs code shipped with the Hub
    # repository. Only use it with model_ids you trust.
    main_export(
        model_name_or_path=model_id,
        output="./",
        opset=opset,
        trust_remote_code=True,
        task="feature-extraction",
        dtype="fp32",
    )
except Exception as export_error:
    # Fall back to a prebuilt model.onnx stored in the same Hub repository.
    # Catching Exception rather than using a bare except lets Ctrl-C and
    # SystemExit through, and the cause is reported instead of being hidden.
    print(f"Export failed ({export_error!r}); downloading model.onnx from the Hub instead")
    huggingface_hub.hf_hub_download(repo_id=model_id, filename="model.onnx", local_dir="./")
| |
| |
if "fp32" in precision_to_filename_map:
    print("Exporting the fp32 onnx file...")

    fp32_path = precision_to_filename_map["fp32"]
    # model.onnx is still read later in this script and then deleted, so it
    # must not also serve as the fp32 output.
    if os.path.abspath(fp32_path) == os.path.abspath("model.onnx"):
        raise ValueError("the fp32 output path must differ from model.onnx")

    # add_mean_pooling reads model.onnx and writes the pooled model straight to
    # the target path, so copying model.onnx there first is unnecessary.
    add_mean_pooling("model.onnx", fp32_path, op, IR, number_of_generated_embeddings)

    print("Done\n\n")
|
|
# Dynamically quantize the un-pooled fp32 model for each requested integer
# precision. Pooling is added to each quantized file in place afterwards, so
# the pooling ops themselves are not quantized.
for precision, weight_type in (("int8", QuantType.QInt8), ("uint8", QuantType.QUInt8)):
    if precision not in precision_to_filename_map:
        continue
    print(f"Quantizing fp32 model to {precision}...")
    target_path = precision_to_filename_map[precision]
    quantize_dynamic("model.onnx", target_path, weight_type=weight_type)
    add_mean_pooling(target_path, target_path, op, IR, number_of_generated_embeddings)
    print("Done\n\n")

# The intermediate exported model is no longer needed.
os.remove("model.onnx")
|
|