TensorFlow - Saver.restore not restoring all parameters


Question

I was training a bidirectional LSTM RNN for nearly 24 hours and, because of oscillation in the error, I decided to decrease the learning rate before letting it continue training. Since the model is saved with Saver.save(sess, file) at every epoch, I terminated the training with the CTC loss having come down to approximately 115.
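For context, the saving and restoring I describe follows the usual tf.train.Saver pattern, roughly as in the sketch below (the variable, checkpoint path, and global step here are placeholders, not the ones from my actual script):

    import tensorflow as tf

    # Illustrative variable standing in for the model's parameters.
    w = tf.Variable(tf.truncated_normal([4, 4]), name="w")
    saver = tf.train.Saver()  # by default, saves every saveable variable in the graph

    # Training run: initialize, then save a checkpoint at the end of each epoch.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        ckpt_path = saver.save(sess, "model_ckpt", global_step=7)

    # Later run: restore from the checkpoint instead of re-initializing.
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)
        print(sess.run(w))  # should reproduce the saved values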

Now, after restoring the model, the initial error rate I am getting is around 162, which is inconsistent with the error rates I was getting in the 7th epoch and is roughly what I got in the first epoch. So my impression is that either the "restore" function is not working, or, if it is working, something else is preventing it from taking effect.
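One way to sanity-check whether the restore itself is the problem is to read the checkpoint directly and compare the stored tensors against the variables in the live session right after saver.restore. This is only a diagnostic sketch, not part of my script, and it assumes tf.train.NewCheckpointReader is available in this TensorFlow version and that checkpoint_dir is the directory used by Saver.save:

    import numpy as np
    import tensorflow as tf

    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
    reader = tf.train.NewCheckpointReader(ckpt_path)

    # Every variable stored in the checkpoint, with its shape.
    for name, shape in reader.get_variable_to_shape_map().items():
        print(name, shape)

    # Run inside the session, after saver.restore(sess, ckpt_path): each
    # trainable variable should match the tensor stored under the same name.
    for v in tf.trainable_variables():
        stored = reader.get_tensor(v.name.split(":")[0])
        print(v.name, "matches checkpoint:", np.allclose(stored, sess.run(v)))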

Here is my code:

    graph = tf.Graph()
    with graph.as_default():
        # Graph creation
        graph_start = time.time()
        seq_inputs = tf.placeholder(tf.float32, shape=[None, batch_size, frame_length], name="sequence_inputs")
        seq_lens = tf.placeholder(shape=[batch_size],dtype=tf.int32)
        seq_inputs = seq_bn(seq_inputs,seq_lens)

        initializer = tf.truncated_normal_initializer(mean=0,stddev=0.1)
        forward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                          num_proj = hidden_size,
                                          use_peepholes=use_peephole,
                                          initializer=initializer,
                                          state_is_tuple=True)

        forward = tf.nn.rnn_cell.MultiRNNCell([forward] * n_layers, state_is_tuple=True)

        backward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                           num_proj= hidden_size,
                                           use_peepholes=use_peephole,
                                           initializer=initializer,
                                           state_is_tuple=True)

        backward = tf.nn.rnn_cell.MultiRNNCell([backward] * n_layers, state_is_tuple=True)

        [fw_out, bw_out], _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=forward,
                                                               cell_bw=backward,
                                                               inputs=seq_inputs,
                                                               time_major=True,
                                                               dtype=tf.float32,
                                                               sequence_length=tf.cast(seq_lens, tf.int64))


        # Batch normalize forward output
        mew,var_ = tf.nn.moments(fw_out,axes=[0])
        fw_out = tf.nn.batch_normalization(fw_out,mew,var_,0.1,1,1e-6)
        # fw_out = seq_bn(fw_out,seq_lens)

        # Batch normalize backward output
        mew,var_ = tf.nn.moments(bw_out,axes=[0])
        bw_out = tf.nn.batch_normalization(bw_out,mew,var_,0.1,1,1e-6)
        # bw_out = seq_bn(bw_out,seq_lens)

        # Reshaping forward, and backward outputs for affine transformation
        fw_out = tf.reshape(fw_out,[-1,hidden_size])
        bw_out = tf.reshape(bw_out,[-1,hidden_size])

        # Linear Layer params
        W_fw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        W_bw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        b_out = tf.constant(0.1,shape=[n_chars])

        # Perform an affine transformation
        logits =  tf.add(tf.add(tf.matmul(fw_out,W_fw),tf.matmul(bw_out,W_bw)),b_out)
        logits = tf.reshape(logits,[-1,batch_size,n_chars])

        # Use CTC Beam Search Decoder to decode pred string from the prob map
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_lens)

        # Target params
        indices = tf.placeholder(dtype=tf.int64, shape=[None,2])
        values = tf.placeholder(dtype=tf.int32, shape=[None])
        shape = tf.placeholder(dtype=tf.int64,shape=[2])
        # Make targets
        targets = tf.SparseTensor(indices,values,shape)

        # Compute Loss
        loss = tf.reduce_mean(tf.nn.ctc_loss(logits, targets, seq_lens))
        # Compute error rate based on edit distance
        predicted = tf.to_int32(decoded[0])
        error_rate = tf.reduce_sum(tf.edit_distance(predicted, targets, normalize=False)) / \
                     tf.to_float(tf.size(targets.values))

        tvars = tf.trainable_variables()
        grad, _ = tf.clip_by_global_norm(tf.gradients(loss,tvars),max_grad_norm)
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,momentum=momentum)
        train_step = optimizer.apply_gradients(zip(grad,tvars))
        graph_end = time.time()
        print("Time elapsed for creating graph: %.3f"%(round(graph_end-graph_start,3)))
        # steps per epoch
        start_time = 0
        steps = int(np.ceil(len(data_train.files)/batch_size))

        loss_tr = []
        log_tr = []
        loss_vl = []
        log_vl = []
        err_tr = []
        err_vl = []
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            #sess.run(tf.initialize_all_variables())
            checkpt_path = tf.train.latest_checkpoint(checkpoint_dir)
            print(saver.restore(sess,checkpt_path))
            print("Model restore from 7th epoch 188th step")
            feed = None
            epoch = None
            step = None
            try:
                for epoch in range(7,epochs+1):
                    if epoch==7:
                       initial_step = 189
                    else:
                       initial_step = 0
                    transcript = []
                    loss_val = 0
                    l_pr = 0
                    start_time = time.time()
                    for step in range(initial_step,steps):
                        train_data, transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_train.next_batch()
                        n_frames = np.reshape(n_frames,[-1])
                        feed = {seq_inputs: train_data, indices:targ_indices, values:targ_values, shape:targ_shape, seq_lens:n_frames}
                        del train_data,targ_indices,targ_values,targ_shape,n_frames

                        # Evaluate loss value, decoded transcript, and log probability
                        _,loss_val,deco,l_pr,err_rt_tr = sess.run([train_step,loss,decoded,log_prob,error_rate],
                                                            feed_dict=feed)
                        del feed
                        loss_tr.append(loss_val)
                        log_tr.append(l_pr)
                        err_tr.append(err_rt_tr)

                        # On validation set
                        val_data, val_transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_val.next_batch()
                        n_frames = np.reshape(n_frames, [-1])
                        feed = {seq_inputs: val_data, indices: targ_indices,values: targ_values, shape: targ_shape, seq_lens: n_frames}
                        del val_data, val_transcript,targ_indices,targ_values,targ_shape,n_frames
                        vl_loss, l_val_pr, err_rt_vl = sess.run([loss, log_prob, error_rate], feed_dict=feed)
                        del feed
                        loss_vl.append(vl_loss)
                        log_vl.append(l_val_pr)
                        err_vl.append(err_rt_vl)
                        print("epoch %d, step: %d, tr_loss: %.2f, vl_loss: %.2f, tr_err: %.2f, vl_err: %.2f"
                          % (epoch, step, np.mean(loss_tr), np.mean(loss_vl), err_rt_tr, err_rt_vl))

                    end_time = time.time()
                    elapsed = round(end_time - start_time, 3)

                    # On training set
                    # Select a random index within batch_size
                    sample_index = np.random.randint(0, batch_size)

                    # Fetch the target transcript
                    actual_str = [data_train.reverse_map[i] for i in transcript[sample_index]]

                    # Fetch the decoded path from probability map
                    pred_sparse = tf.SparseTensor(deco[0].indices, deco[0].values, deco[0].shape)
                    pred_dense = tf.sparse_tensor_to_dense(pred_sparse)
                    ans = pred_dense.eval()
                    #pred = [data_train.reverse_map[i] for i in ans[sample_index, :]]
                    pred = []
                    for i in ans[sample_index,:]:
                        if i == n_chars-1:
                            pred.append(data_train.reverse_map[0])
                        else:
                            pred.append(data_train.reverse_map[i])
                    print("time_elapsed for 200 steps: %.3f, " % (elapsed))
                    if epoch%2 == 0:
                        print("Sample mini-batch results: \n" \
                              "predicted string: ", np.array(pred))
                        print("actual string: ", np.array(actual_str))
                    print("On training set, the loss: %.2f, log_pr: %.3f, error rate %.3f:"% (loss_val, np.mean(l_pr), err_rt_tr))
                    print("On validation set, the loss: %.2f, log_pr: %.3f, error rate: %.3f" % (vl_loss, np.mean(l_val_pr), err_rt_vl))

                    # Save the trainable parameters after the end of an epoch
                    if epoch > 7:
                        path = saver.save(sess, 'model_%d' % epoch)
                    print("Session saved at: %s" % path)
                    np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
            except (KeyboardInterrupt, SystemExit, Exception), e:
                print("Error/Interruption: %s" % str(e))
                exc_type, exc_obj, exc_tb = sys.exc_info()
                print("Line no: %d" % exc_tb.tb_lineno)
                if epoch > 7:
                    print("Saving model: %s" % saver.save(sess, 'Last.cpkt'))
                print("Current batch: %d" % data_train.b_id)
                print("Current epoch: %d" % epoch)
                print("Current step: %d"%step)
                np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
                print("Clossing TF Session...")
                sess.close()
                print("Terminating Program...")
                sys.exit(0)

Answer

I think you need to re-initialize your accumulators for each epoch, so these lines must be moved to the beginning of the epoch loop (see the sketch after these lines):

    loss_tr = []
    log_tr = []
    loss_vl = []
    log_vl = []
    err_tr = []
    err_vl = []
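Something along these lines (only a sketch of the loop structure; epochs and steps stand in for the real values, and the loop body stays as it is):

    # Sketch: reset the accumulators at the start of every epoch.
    epochs, steps = 10, 200  # placeholders for the real values

    for epoch in range(7, epochs + 1):
        # Re-initialized per epoch, so np.mean(loss_tr) etc. in the progress
        # print reflect only the current epoch rather than the whole run.
        loss_tr, log_tr, err_tr = [], [], []
        loss_vl, log_vl, err_vl = [], [], []
        for step in range(steps):
            pass  # ... training / validation body unchanged ...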

