Python模型加载案例有哪些?

wen python案例 3

本文目录导读:

  1. PyTorch 模型加载(最常用)
  2. TensorFlow / Keras 模型加载
  3. Hugging Face Transformers(NLP 模型)
  4. Scikit-learn 模型加载(传统机器学习)
  5. ONNX 模型加载(跨平台/跨框架推理)
  6. TFLite(移动端/边缘设备)
  7. 特殊情况:多 GPU 模型加载

Python模型加载案例非常丰富,主要取决于你使用的框架(如PyTorch、TensorFlow、Scikit-learn)、模型格式(如.pth、.h5、.onnx)以及部署场景(推理、迁移学习、服务化)。

下面按常用框架分类,列举几个典型的加载案例及代码示例。


PyTorch 模型加载(最常用)

PyTorch 通常推荐保存模型的状态字典state_dict),加载时需先实例化相同的模型结构。

案例 A:加载完整模型(包含结构,不推荐用于生产)

import torch
import torchvision.models as models
# 1. 保存时使用 torch.save(model, 'model.pth')
# 2. 加载
model = torch.load('model.pth')  
model.eval()  # 切换到评估模式

案例 B:加载 state_dict(标准做法)

import torch
import torchvision.models as models
# 1. 先定义与训练时完全相同的模型结构
model = models.resnet18(weights=None, num_classes=10)
# 2. 加载权重
state_dict = torch.load('model_state_dict.pth', map_location='cpu')  # 防止GPU不存在
model.load_state_dict(state_dict)
# 3. 切换到评估模式
model.eval()
# 4. 推理
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(dummy_input)

案例 C:加载 Checkpoint(包含优化器状态,用于继续训练)

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

TensorFlow / Keras 模型加载

案例 A:加载整个模型(.h5 或 SavedModel 格式)

import tensorflow as tf
# 加载保存的完整模型(结构+权重+优化器)
model = tf.keras.models.load_model('my_model.h5')
# 推理
import numpy as np
dummy_input = np.random.rand(1, 224, 224, 3)
predictions = model.predict(dummy_input)

案例 B:仅加载权重(需先定义模型结构)

# 先定义相同结构
model = tf.keras.Sequential([...])  # 或使用函数式API
# 加载权重
model.load_weights('model_weights.h5')
# 编译(如果需要进行评估或继续训练)
model.compile(optimizer='adam', loss='categorical_crossentropy')

Hugging Face Transformers(NLP 模型)

案例 A:从本地文件加载预训练模型

from transformers import AutoModel, AutoTokenizer
# 假设已经下载或手动保存了模型到本地目录
model = AutoModel.from_pretrained('./my_bert_model/')
tokenizer = AutoTokenizer.from_pretrained('./my_bert_model/')
# 推理
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)

案例 B:通过模型库名称加载(自动下载)

from transformers import pipeline
# 一行加载并推理
classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
result = classifier("I love this movie!")

Scikit-learn 模型加载(传统机器学习)

import joblib
# 加载之前用 joblib.dump 保存的模型
model = joblib.load('trained_model.pkl')
# 推理
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)

ONNX 模型加载(跨平台/跨框架推理)

ONNX 允许你在不同框架(如 PyTorch → ONNX → TensorRT)之间交换模型。

import onnxruntime as ort
import numpy as np
# 创建一个 ONNX Runtime 推理会话
ort_session = ort.InferenceSession('model.onnx')
# 获取输入输出名称
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# 准备输入数据(必须与 ONNX 模型要求的 shape 一致)
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 推理
outputs = ort_session.run([output_name], {input_name: input_data})

TFLite(移动端/边缘设备)

import tensorflow as tf
import numpy as np
# 加载 TFLite 模型
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
# 获取输入输出 tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# 填充输入
input_data = np.random.randn(1, 224, 224, 3).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
# 推理
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

特殊情况:多 GPU 模型加载

如果模型是使用 DataParallel 训练保存的,权重键名会带有 module. 前缀。

import torch
# 方法1:移除前缀
state_dict = torch.load('model.pth')
new_state_dict = {}
for k, v in state_dict.items():
    name = k.replace("module.", "")  # 去掉module.前缀
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
# 方法2:直接加载到单个GPU或CPU(PyTorch 1.1+ 会自动处理)
model = torch.nn.DataParallel(model)  # 重新包装
model.load_state_dict(torch.load('model.pth'))

场景 推荐格式 加载方式
PyTorch 训练/推理 .pth, .pt torch.load() + model.load_state_dict()
TensorFlow 部署 .h5 或 SavedModel tf.keras.models.load_model()
传统机器学习 (sklearn) .pkl joblib.load()
跨框架/移动端 .onnx, .tflite onnxruntime.InferenceSession() / tf.lite.Interpreter()
NLP 模型 (Hugging Face) config.json, pytorch_model.bin AutoModel.from_pretrained()

小提示: 生产环境中,建议将模型转换为 ONNXTensorRT(针对NVIDIA GPU)以获得最佳性能和可移植性。

标签: ONNX PyTorch

抱歉,评论功能暂时关闭!