How to train a model in nodejs (tensorflow.js)?


Question

I want to make an image classifier, but I don't know Python. Tensorflow.js works with JavaScript, which I am familiar with. Can models be trained with it, and what would be the steps to do so? Frankly, I have no clue where to start.

The only thing I figured out is how to load "mobilenet", which apparently is a set of pre-trained models, and classify images with it:

const tf = require('@tensorflow/tfjs'),
      mobilenet = require('@tensorflow-models/mobilenet'),
      tfnode = require('@tensorflow/tfjs-node'),
      fs = require('fs-extra');

const imageBuffer = await fs.readFile(......),
      tfimage = tfnode.node.decodeImage(imageBuffer),
      mobilenetModel = await mobilenet.load();  

const results = await mobilenetModel.classify(tfimage);

This works, but it's no use to me because I want to train my own model using my own images with labels that I create.

========================

Say I have a bunch of images and labels. How do I use them to train a model?

const myData = JSON.parse(await fs.readFile('files.json'));

for(const data of myData){
  const image = await fs.readFile(data.imagePath),
        labels = data.labels;

  // how to train, where to pass image and labels ?

}

Recommended answer

First of all, the images need to be converted to tensors. The first approach is to create one tensor containing all the features (and, correspondingly, one tensor containing all the labels). This is the way to go only if the dataset contains few images, because everything is held in memory at once.

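  // decode one image file into a tensor of shape [height, width, channels]
  const imageBuffer = await fs.readFile(feature_file);
  const tensorFeature = tfnode.node.decodeImage(imageBuffer);

  // repeat the two lines above for every image, then stack the results into
  // a single tensor of shape [numberOfImages, height, width, channels]
  // (tf.stack requires all images to have the same shape)
  const tensorFeatures = tf.stack([tensorFeature, tensorFeature2, tensorFeature3]);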

The labels would be an array indicating the class of each image

 const labelArray = [0, 1, 2]; // e.g. 0 for dog, 1 for cat and 2 for bird
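If the labels start out as class names rather than numbers, they first have to be mapped to integer indices. Here is a minimal sketch in plain JavaScript, assuming each record carries a single string label. The `records` array and the `buildLabelIndex` helper are hypothetical and not part of TensorFlow.js.

```javascript
// Hypothetical records: one class name per image.
const records = [
  {imagePath: 'img/a.jpg', label: 'dog'},
  {imagePath: 'img/b.jpg', label: 'cat'},
  {imagePath: 'img/c.jpg', label: 'dog'},
  {imagePath: 'img/d.jpg', label: 'bird'},
];

// Give each distinct class name an integer index, in order of first appearance.
function buildLabelIndex(items) {
  const classNames = [];
  const labelArray = items.map(({label}) => {
    let index = classNames.indexOf(label);
    if (index === -1) {
      index = classNames.push(label) - 1; // push returns the new length
    }
    return index;
  });
  return {classNames, labelArray};
}

const {classNames, labelArray} = buildLabelIndex(records);
console.log(classNames); // the class names, indexed by label value
console.log(labelArray); // one integer label per record
```

Keeping `classNames` around makes it possible to turn a predicted index back into a readable name later.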

The labels now need to be one-hot encoded

 tensorLabels = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 3);
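To make the encoding concrete, the following plain-JavaScript sketch computes the rows that one-hot encoding produces for in-range labels. The `oneHotRows` helper is hypothetical and only illustrates the idea. In real code `tf.oneHot` does this on tensors, and unlike this sketch it does not throw on out-of-range labels.

```javascript
// Turn integer labels into one-hot rows: a 1 at the label's position, 0 elsewhere.
function oneHotRows(labels, numClasses) {
  return labels.map(label => {
    if (!Number.isInteger(label) || label < 0 || label >= numClasses) {
      throw new RangeError(`label ${label} is outside 0..${numClasses - 1}`);
    }
    const row = new Array(numClasses).fill(0);
    row[label] = 1;
    return row;
  });
}

const rows = oneHotRows([0, 1, 2], 3);
console.log(rows); // [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
```

Each row has as many entries as there are classes, which is why the last dense layer of the model needs one unit per class.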

Once the tensors are ready, one needs to create the model for training. Here is a simple model.

const model = tf.sequential();
model.add(tf.layers.conv2d({
  inputShape: [height, width, numberOfChannels], // numberOfChannels = 3 for colorful images and one otherwise
  filters: 32,
  kernelSize: 3,
  activation: 'relu',
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
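It can help to check the shapes this model produces. The sketch below does the arithmetic in plain JavaScript, assuming the `tf.layers.conv2d` defaults (stride 1 and 'valid' padding) and a hypothetical 28x28 RGB input. `modelSummary` is an illustrative helper, not a TensorFlow.js function.

```javascript
// Shape and parameter arithmetic for: conv2d -> flatten -> dense(softmax).
// Assumes stride 1 and 'valid' padding, so each spatial side shrinks by kernelSize - 1.
function modelSummary({height, width, channels, filters, kernelSize, numClasses}) {
  const outHeight = height - kernelSize + 1;
  const outWidth = width - kernelSize + 1;
  // one kernelSize x kernelSize x channels kernel per filter, plus one bias per filter
  const convParams = kernelSize * kernelSize * channels * filters + filters;
  const flattened = outHeight * outWidth * filters;
  // one weight per (flattened input, class) pair, plus one bias per class
  const denseParams = flattened * numClasses + numClasses;
  return {
    convOutputShape: [outHeight, outWidth, filters],
    flattened,
    convParams,
    denseParams,
    totalParams: convParams + denseParams,
  };
}

const summary = modelSummary({
  height: 28, width: 28, channels: 3, // hypothetical input size
  filters: 32, kernelSize: 3, numClasses: 3,
});
console.log(summary);
```

The dense layer dominates the parameter count and grows with the image area, which is one reason to resize large images before training. On a real model, `model.summary()` prints the corresponding figures.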

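Then the model can be compiled and trained. A model must be compiled before fit is called; the optimizer and loss below are one reasonable choice for one-hot labels, not the only one.

model.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});
await model.fit(tensorFeatures, tensorLabels);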

If the dataset contains a lot of images, one would need to create a tf.data.Dataset instead, so that the images are loaded as they are needed rather than all at once. This answer discusses why.

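// Read one image file and decode it into a tensor.
// readFileSync is used because a plain generator function cannot await.
const genFeatureTensor = imagePath => {
  const imageBuffer = fs.readFileSync(imagePath);
  return tfnode.node.decodeImage(imageBuffer);
};

// One-hot array for a class index: 1 at position index, 0 elsewhere
const labelArray = index => Array.from({length: numberOfClasses}, (_, k) => k === index ? 1 : 0);

// imagePaths and classIndices are placeholder arrays with one entry per image
function* dataGenerator() {
  const numElements = numberOfImages;
  let index = 0;
  while (index < numElements) {
    const feature = genFeatureTensor(imagePaths[index]);
    const label = tf.tensor1d(labelArray(classIndices[index]));
    index++;
    yield {xs: feature, ys: label};
  }
}

// fitDataset expects batched elements, so batch the dataset (batchSize is a placeholder)
const ds = tf.data.generator(dataGenerator).batch(batchSize);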

And use model.fitDataset(ds) to train the model
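The reason a generator helps with large datasets is that items are produced one at a time instead of all being held in memory. The following plain-JavaScript sketch mimics that behaviour without TensorFlow.js. `loadItem` stands in for reading and decoding an image, and `batch` is a hypothetical helper playing the role that `ds.batch(n)` plays on a real dataset.

```javascript
// Counts how many items have been "loaded" so far, to show the laziness.
let loadedCount = 0;

// Stand-in for reading and decoding one image plus its label.
function loadItem(index) {
  loadedCount++;
  return {xs: `image-${index}`, ys: index % 3};
}

// Produces items only when the consumer asks for the next one.
function* dataGenerator(numElements) {
  for (let index = 0; index < numElements; index++) {
    yield loadItem(index);
  }
}

// Groups consecutive items into arrays of batchSize; the last batch may be smaller.
function* batch(iterable, batchSize) {
  let current = [];
  for (const item of iterable) {
    current.push(item);
    if (current.length === batchSize) {
      yield current;
      current = [];
    }
  }
  if (current.length > 0) {
    yield current;
  }
}

const batches = batch(dataGenerator(5), 2);
const firstBatch = batches.next().value;
const loadedAfterFirstBatch = loadedCount; // only the first batch has been loaded so far
const remainingSizes = [...batches].map(b => b.length);
console.log(firstBatch, loadedAfterFirstBatch, remainingSizes);
```

Only the items of the current batch exist at any moment, which is what keeps memory use bounded when the dataset is large.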

The above is for training in nodejs. To do the same processing in the browser, genFeatureTensor can be written as follows:

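// Load an image element from a URL and resolve once it has finished loading
function load(url){
  return new Promise((resolve, reject) => {
    const im = new Image();
    im.crossOrigin = 'anonymous';
    im.onload = () => resolve(im);
    im.onerror = reject;
    im.src = url; // set src last so the handlers are already attached
  });
}

// Note: this version returns a Promise, so the caller must await it
const genFeatureTensor = async image => {
  const img = await load(image);
  return tf.browser.fromPixels(img);
};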

One word of caution: heavy processing might block the main thread in the browser. This is where web workers come into play.
