Java中的异或神经网络 [英] XOR Neural Network in Java

查看:28
本文介绍了Java中的异或神经网络的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试为 Java 中的 XOR 函数实现和训练一个带有反向传播的五神经元神经网络.我的代码(请原谅它的丑陋):

I'm trying to implement and train a five neuron neural network with back propagation for the XOR function in Java. My code (please excuse it's hideousness):

public class XORBackProp {

private static final int MAX_EPOCHS = 500;

//weights
private static double w13, w23, w14, w24, w35, w45;
private static double theta3, theta4, theta5;
//neuron outputs
private static double gamma3, gamma4, gamma5;
//neuron error gradients
private static double delta3, delta4, delta5;
//weight corrections
private static double dw13, dw14, dw23, dw24, dw35, dw45, dt3, dt4, dt5;
//learning rate
private static double alpha = 0.1;
private static double error;
private static double sumSqrError;
private static int epochs = 0;
private static boolean loop = true;

private static double sigmoid(double exponent)
{
    return (1.0/(1 + Math.pow(Math.E, (-1) * exponent)));
}

private static void activateNeuron(int x1, int x2, int gd5)
{
    gamma3 = sigmoid(x1*w13 + x2*w23 - theta3);
    gamma4 = sigmoid(x1*w14 + x2*w24 - theta4);
    gamma5 = sigmoid(gamma3*w35 + gamma4*w45 - theta5);

    error = gd5 - gamma5;

    weightTraining(x1, x2);
}

private static void weightTraining(int x1, int x2)
{
    delta5 = gamma5 * (1 - gamma5) * error;
    dw35 = alpha * gamma3 * delta5;
    dw45 = alpha * gamma4 * delta5;
    dt5 = alpha * (-1) * delta5;

    delta3 = gamma3 * (1 - gamma3) * delta5 * w35;
    delta4 = gamma4 * (1 - gamma4) * delta5 * w45;

    dw13 = alpha * x1 * delta3;
    dw23 = alpha * x2 * delta3;
    dt3 = alpha * (-1) * delta3;
    dw14 = alpha * x1 * delta4;
    dw24 = alpha * x2 * delta4;
    dt4 = alpha * (-1) * delta4;

    w13 = w13 + dw13;
    w14 = w14 + dw14;
    w23 = w23 + dw23;
    w24 = w24 + dw24;
    w35 = w35 + dw35;
    w45 = w45 + dw45;
    theta3 = theta3 + dt3;
    theta4 = theta4 + dt4;
    theta5 = theta5 + dt5;
}

public static void main(String[] args)
{

    w13 = 0.5;
    w14 = 0.9;
    w23 = 0.4;
    w24 = 1.0;
    w35 = -1.2;
    w45 = 1.1;
    theta3 = 0.8;
    theta4 = -0.1;
    theta5 = 0.3;

    System.out.println("XOR Neural Network");

    while(loop)
    {
        activateNeuron(1,1,0);
        sumSqrError = error * error;
        activateNeuron(0,1,1);
        sumSqrError += error * error;
        activateNeuron(1,0,1);
        sumSqrError += error * error;
        activateNeuron(0,0,0);
        sumSqrError += error * error;

        epochs++;

        if(epochs >= MAX_EPOCHS)
        {
            System.out.println("Learning will take more than " + MAX_EPOCHS + " epochs, so program has terminated.");
            System.exit(0);
        }

        System.out.println(epochs + " " + sumSqrError);

        if (sumSqrError < 0.001)
        {
            loop = false;
        }
    }
}
}

如果有帮助,这里是一个网络图.

If it helps any, here's a diagram of the network.

所有权重和学习率的初始值直接取自我教科书中的一个例子.目标是训练网络直到误差平方和小于 0.001.教科书还给出了第一次迭代(1,1,0)后所有权重的值,我已经测试了我的代码,其结果与教科书的结果完美匹配.但根据这本书,这应该只需要 224 个 epoch 就可以收敛.但是当我运行它时,它总是达到 MAX_EPOCHS,除非它设置为几千.我做错了什么?

The initial values for all the weights and the learning rate are taken straight from an example in my textbook. The goal is to train the network until the sum of the squared errors is less than .001. The textbook also gives the values of all the weights after the first iteration (1,1,0) and I've tested my code and its results match the textbook's results perfectly. But according to the book, this should only take 224 epochs to converge. But when I run it, it always reaches MAX_EPOCHS unless it is set to several thousand. What am I doing wrong?

推荐答案

    //Add this in the constants declaration section.
    private static double alpha = 3.8, g34 = 0.13, g5 = 0.21;

    // Add this in activate neuron
    gamma3 = sigmoid(x1 * w13 + x2 * w23 - theta3);
    gamma4 = sigmoid(x1 * w14 + x2 * w24 - theta4);        
    if (gamma3 > 1 - g34 ) {gamma3 = 1;}
    if (gamma3 < g34) {gamma3 = 0;}
    if (gamma4 > 1- g34) {gamma4 = 1;}
    if (gamma4 < g34) {gamma4 = 0;}   
    gamma5 = sigmoid(gamma3 * w35 + gamma4 * w45 - theta5);
    if (gamma5 > 1 - g5) {gamma5 = 1;}
    if (gamma5 < g5) {gamma5 = 0;}

ANN 应该在 66 次迭代中学习,但处于发散的边缘.

ANN should learn in 66 iterations, but is on the brink of divergence.

这篇关于Java中的异或神经网络的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆