判断一个模型是pytorch模型还是tensorflow模型还是scikit模型 [英] Determine whether a model is pytorch model or a tensorflow model or scikit model

查看:166
本文介绍了判断一个模型是pytorch模型还是tensorflow模型还是scikit模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如果我想确定模型的类型,即以编程方式从哪个框架构建模型,是否有办法做到?
我有一些序列化的模型(例如,泡菜文件).为简单起见,假设我的模型可以是tensorflow的,pytorch的或scikit Learn的.如何以编程方式确定这3个中的哪一个是?

If I want to determine the type of model i.e. from which framework was it made programmatically, is there a way to do that?
I have a model in some serialized manner(Eg. a pickle file). For simplicity purposes, assume that my model can be either tensorflow's, pytorch's or scikit learn's. How can I determine programmatically which one of these 3 is the one?

推荐答案

AFAIK,我从未听说过将Tensorflow/Keras和Pytorch模型与pickle或joblib一起保存-这些框架提供了自己的保存和保存功能.正在加载模型:请参见SO线程 Tensorflow:如何保存/恢复模型?在PyTorch中保存经过训练的模型的最佳方法?.此外,在尝试保存Tensorflow模型时,有一个 Github线程报告了各种问题.用pickle和joblib.

AFAIK, I have never heard of Tensorflow/Keras and Pytorch models to be saved with pickle or joblib - these frameworks provide their own functionality for saving & loading models: see the SO threads Tensorflow: how to save/restore a model? and Best way to save a trained model in PyTorch?. Additionally, there is a Github thread reporting all kinds of issues when trying to save Tensorflow models with pickle and joblib.

鉴于,如果您向模型加载了咸菜,那么使用 type(model) model 来查看其类型是微不足道的.这是使用scikit-learn线性回归模型的简短演示:

Given that, if you have loaded a model with, say, pickle, it is trivial to see what type it is using type(model) and model. Here is a short demonstration with a scikit-learn linear regression model:

import numpy as np
from sklearn.linear_model import LinearRegression

X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = np.dot(X, np.array([1, 2])) + 3
reg = LinearRegression()
reg.fit(X, y)

# save it

import pickle

filename = 'model1.pkl'
pickle.dump(reg, open(filename, 'wb'))

现在,加载模型:

loaded_model = pickle.load(open(filename, 'rb'))

type(loaded_model)
# sklearn.linear_model._base.LinearRegression

loaded_model
# LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)

这还将与XGBoost,LightGBM,CatBoost等框架一起使用.

This will also work with frameworks like XGBoost, LightGBM, CatBoost etc.

这篇关于判断一个模型是pytorch模型还是tensorflow模型还是scikit模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆