How to train a model in nodejs (tensorflow.js)?
Question
I want to make an image classifier, but I don't know Python. TensorFlow.js works with JavaScript, which I am familiar with. Can models be trained with it, and what would be the steps to do so? Frankly, I have no clue where to start.
The only thing I figured out is how to load "mobilenet", which apparently is a set of pre-trained models, and classify images with it:
const tf = require('@tensorflow/tfjs'),
  mobilenet = require('@tensorflow-models/mobilenet'),
  tfnode = require('@tensorflow/tfjs-node'),
  fs = require('fs-extra');
// in a CommonJS script, `await` is only allowed inside an async function
(async () => {
  const imageBuffer = await fs.readFile(......),
    tfimage = tfnode.node.decodeImage(imageBuffer),
    mobilenetModel = await mobilenet.load();
  const results = await mobilenetModel.classify(tfimage);
})();
This works, but it's no use to me because I want to train my own model on my images, with labels that I create.
========================
Say I have a bunch of images and labels. How do I use them to train a model?
const myData = JSON.parse(await fs.readFile('files.json'));
for(const data of myData){
const image = await fs.readFile(data.imagePath),
labels = data.labels;
// how to train, where to pass image and labels ?
}
Recommended answer
First of all, the images need to be converted to tensors. The first approach is to create one tensor containing all the features (and, respectively, one tensor containing all the labels). Since every image then has to be held in memory at once, this should be the way to go only if the dataset contains few images.
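const imageBuffer = await fs.readFile(feature_file);
// create a tensor of shape [height, width, 3] for the image (3 = force RGB output)
const tensorFeature = tfnode.node.decodeImage(imageBuffer, 3);
// create an array of all the features by iterating over all the images
// (tf.stack requires all of them to have the same shape)
const tensorFeatures = tf.stack([tensorFeature, tensorFeature2, tensorFeature3]); // shape [3, height, width, 3]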
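tf.stack only works when every image tensor has the same shape, and the stacked tensor keeps every image in memory at once. The plain-JavaScript sketch below (no TensorFlow.js needed; stackedShape and tensorBytes are hypothetical helpers) checks the shapes and estimates the memory, assuming 4 bytes per element, which holds both for the int32 tensors decodeImage returns and for float32. Images of different sizes would first need resizing, for example with tf.image.resizeBilinear.

```javascript
// Shape that tf.stack would produce from a list of image shapes.
// All inputs must share one shape, so a mismatch is reported here
// instead of failing later inside tf.stack.
function stackedShape(shapes) {
  if (shapes.length === 0) {
    throw new Error('need at least one image shape');
  }
  const [first] = shapes;
  for (const s of shapes) {
    if (s.length !== first.length || s.some((d, i) => d !== first[i])) {
      throw new Error(`shape [${s}] differs from [${first}]; resize the images first`);
    }
  }
  return [shapes.length, ...first];
}

// Approximate memory of a dense tensor, assuming 4-byte elements
// (int32 and float32 both use 4 bytes per element).
function tensorBytes(shape, bytesPerElement = 4) {
  return shape.reduce((acc, d) => acc * d, 1) * bytesPerElement;
}

// Assumed example: 10,000 RGB images of 224x224 pixels.
const shape = stackedShape(Array(10000).fill([224, 224, 3]));
console.log(shape);              // [ 10000, 224, 224, 3 ]
console.log(tensorBytes(shape)); // 6021120000
```

With those assumed numbers the stacked tensor alone takes about 6 GB, which is why this approach only suits small datasets.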
The labels would be an array indicating the class of each image:
const labelArray = [0, 1, 2]; // e.g. 0 for dog, 1 for cat and 2 for bird
Next, the labels need to be one-hot encoded:
const tensorLabels = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 3); // depth 3 = number of classes
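tf.oneHot turns each class index into a vector of length depth that contains a single 1. Here is the same transformation in plain JavaScript, as a hypothetical oneHot helper; the range check is an extra added in this sketch so that a wrong label surfaces early. The depth has to match the number of units in the model's final dense layer.

```javascript
// Plain-JavaScript equivalent of one-hot encoding a 1D list of class
// indices: row i has a 1 at column labels[i] and 0 everywhere else.
function oneHot(labels, depth) {
  return labels.map(label => {
    // extra check in this sketch: reject labels outside [0, depth)
    if (!Number.isInteger(label) || label < 0 || label >= depth) {
      throw new RangeError(`label ${label} is outside [0, ${depth})`);
    }
    return Array.from({length: depth}, (_, k) => (k === label ? 1 : 0));
  });
}

console.log(oneHot([0, 1, 2, 1], 3));
// [ [ 1, 0, 0 ], [ 0, 1, 0 ], [ 0, 0, 1 ], [ 0, 1, 0 ] ]
```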
Once the tensors are ready, the next step is to create the model to train. Here is a simple one:
const model = tf.sequential();
model.add(tf.layers.conv2d({
  inputShape: [height, width, numberOfChannels], // numberOfChannels = 3 for color (RGB) images, 1 for grayscale
filters: 32,
kernelSize: 3,
activation: 'relu',
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 3, activation: 'softmax'})); // units = number of classes
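To see what this model amounts to in numbers, here is a plain-JavaScript sketch that computes each layer's output shape and trainable-parameter count. It assumes the tfjs defaults for conv2d (strides 1, padding 'valid', bias enabled) and an example input of 224x224 RGB images; describeModel is a hypothetical helper, and model.summary() can be used to check the figures against a real model.

```javascript
// Output shapes and trainable-parameter counts for
// conv2d (strides 1, padding 'valid', with bias) -> flatten -> dense.
function describeModel(height, width, channels, {filters = 32, kernelSize = 3, classes = 3} = {}) {
  // 'valid' padding with stride 1 shrinks each spatial side by kernelSize - 1
  const convOut = [height - kernelSize + 1, width - kernelSize + 1, filters];
  // one kernelSize x kernelSize x channels kernel plus one bias per filter
  const convParams = (kernelSize * kernelSize * channels + 1) * filters;
  // flatten turns the conv output into a single vector
  const flatUnits = convOut[0] * convOut[1] * convOut[2];
  // one weight per (input, class) pair plus one bias per class
  const denseParams = flatUnits * classes + classes;
  return {convOut, convParams, flatUnits, denseParams, total: convParams + denseParams};
}

console.log(describeModel(224, 224, 3));
```

For 224x224x3 input the convolution outputs 222x222x32 with 896 parameters, flatten produces 1,577,088 values, and the dense layer therefore has 4,731,267 parameters. Almost all of the model sits in that last layer, which is why image models usually shrink the feature map (for example with tf.layers.maxPooling2d) before flattening.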
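The model then has to be compiled (given an optimizer and a loss) before it can be trained:
model.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});
await model.fit(tensorFeatures, tensorLabels, {epochs: 10}); // epochs: example value (the default is 1)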
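When the features are built by iterating over the images class by class, the rows of tensorFeatures and tensorLabels end up grouped by class. model.fit shuffles the training data by default, but its validationSplit option takes the last samples as validation data before shuffling (as in Keras), so grouped data would give a validation set made of a single class. Shuffling features and labels with the same permutation before stacking avoids that. A plain-JavaScript sketch; shuffleTogether and the file names are hypothetical.

```javascript
// Shuffle features and labels with the same random permutation
// (Fisher-Yates), so every example keeps its own label.
// `rng` defaults to Math.random and must return values in [0, 1).
function shuffleTogether(features, labels, rng = Math.random) {
  if (features.length !== labels.length) {
    throw new Error('features and labels must have the same length');
  }
  const order = features.map((_, i) => i);
  for (let i = order.length - 1; i > 0; i--) {
    const j = Math.floor(rng() * (i + 1)); // random index in [0, i]
    [order[i], order[j]] = [order[j], order[i]];
  }
  return {
    features: order.map(i => features[i]),
    labels: order.map(i => labels[i]),
  };
}

// Hypothetical file list grouped by class, as it often is when read from disk.
const paths = ['dog1.jpg', 'dog2.jpg', 'cat1.jpg', 'cat2.jpg', 'bird1.jpg', 'bird2.jpg'];
const classes = [0, 0, 1, 1, 2, 2]; // 0 = dog, 1 = cat, 2 = bird
const shuffled = shuffleTogether(paths, classes);
console.log(shuffled.features, shuffled.labels);
```

The shuffled paths would then be decoded and stacked in that order, and the shuffled labels one-hot encoded.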
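If the dataset contains a lot of images, one would need to create a tf.data.Dataset instead, so that images are read and decoded one at a time as training needs them rather than all being held in memory at once. This answer discusses why in more detail.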
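// imagePaths, classIndices, numberOfClasses and numberOfImages are assumed to be defined elsewhere
const genFeatureTensor = imagePath => {
  // synchronous read, so it can be called from a (synchronous) generator
  const imageBuffer = fs.readFileSync(imagePath);
  return tfnode.node.decodeImage(imageBuffer, 3); // 3 = force RGB output
};
// one-hot label array, e.g. labelArray(1) -> [0, 1, 0] for 3 classes
const labelArray = indice => Array.from({length: numberOfClasses}, (_, k) => k === indice ? 1 : 0);
function* dataGenerator() {
  let index = 0;
  while (index < numberOfImages) {
    const feature = genFeatureTensor(imagePaths[index]);
    const label = tf.tensor1d(labelArray(classIndices[index]));
    index++;
    yield {xs: feature, ys: label};
  }
}
// fitDataset expects a dataset of batches, hence .batch(); 32 is an example batch size
const ds = tf.data.generator(dataGenerator).batch(32);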
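Then train the model with await model.fitDataset(ds, {epochs: 10}); the model must be compiled first, and the number of epochs is an example value.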
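In this pipeline tf.data.generator hands out one example at a time, Dataset.batch(size) groups consecutive examples (keeping a smaller final batch by default), and fitDataset trains on one batch at a time. Here is that flow in plain JavaScript, without TensorFlow.js; loadExample is a hypothetical placeholder for decoding an image and encoding its label.

```javascript
// Hypothetical stand-in for "decode image i and encode its label";
// it returns small placeholder arrays so the flow can run anywhere.
const loadExample = i => ({xs: [i], ys: [i % 3]});

// Yields one example at a time, like dataGenerator: nothing is loaded
// until the consumer asks for the next element.
function* examples(count, load) {
  for (let i = 0; i < count; i++) {
    yield load(i);
  }
}

// Groups consecutive examples into batches of `size`; the final batch
// keeps whatever is left over, so it may be smaller.
function* batches(iterable, size) {
  let xs = [];
  let ys = [];
  for (const example of iterable) {
    xs.push(example.xs);
    ys.push(example.ys);
    if (xs.length === size) {
      yield {xs, ys};
      xs = [];
      ys = [];
    }
  }
  if (xs.length > 0) {
    yield {xs, ys};
  }
}

const sizes = [...batches(examples(10, loadExample), 4)].map(b => b.xs.length);
console.log(sizes); // [ 4, 4, 2 ]
```

With 10 images and a batch size of 4, one epoch therefore consists of 3 batches, and at most one batch of decoded images needs to be in memory at a time.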
The above is for training in Node.js. To do this processing in the browser, genFeatureTensor can be written as follows:
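function loadImage(url) {
  return new Promise((resolve, reject) => {
    const im = new Image();
    im.crossOrigin = 'anonymous'; // allow reading pixels of cross-origin images served with CORS headers
    im.onload = () => resolve(im);
    im.onerror = reject;
    im.src = url; // set src last so the handlers are already attached
  });
}
// async: callers must await the returned Promise
const genFeatureTensor = async imageUrl => {
  const img = await loadImage(imageUrl);
  return tf.browser.fromPixels(img); // int32 tensor of shape [height, width, 3]
};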
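With its default of three channels, tf.browser.fromPixels returns an int32 tensor of shape [height, width, 3]: the image's alpha channel is dropped and the values stay in the range 0 to 255. The plain-JavaScript sketch below (a hypothetical rgbaToRgb helper that runs without a browser) shows that conversion on RGBA data laid out like a canvas ImageData.data array.

```javascript
// Drops the alpha channel from row-by-row RGBA pixel data (the layout of
// a canvas ImageData.data array) and returns height * width * 3 RGB
// integers in [0, 255].
function rgbaToRgb(rgba, width, height) {
  if (rgba.length !== width * height * 4) {
    throw new Error('expected width * height * 4 RGBA values');
  }
  const rgb = new Int32Array(width * height * 3);
  for (let p = 0; p < width * height; p++) {
    rgb[p * 3] = rgba[p * 4];         // red
    rgb[p * 3 + 1] = rgba[p * 4 + 1]; // green
    rgb[p * 3 + 2] = rgba[p * 4 + 2]; // blue (rgba[p * 4 + 3] is alpha, dropped)
  }
  return rgb;
}

// A 2x1 image: one opaque red pixel and one half-transparent blue pixel.
const rgba = new Uint8ClampedArray([255, 0, 0, 255, 0, 0, 255, 128]);
console.log(Array.from(rgbaToRgb(rgba, 2, 1))); // [ 255, 0, 0, 0, 0, 255 ]
```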
A word of caution: heavy processing can block the browser's main thread. This is where web workers come into play, since they run that work on a separate thread.