Tensorflow 2.0: flat_map() to flatten Dataset of Dataset returns cardinality -2

Problem description

I am trying to run the following code (as given in Tensorflow documentation) to create windows of my data and then flatten the dataset of datasets.

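import tensorflow as tf

# range_ds is not defined in the question; the tf.data guide's windowing example uses this source (assumed here).
range_ds = tf.data.Dataset.range(100000)
window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
    print(sub_ds)

flat_windows = windows.flat_map(lambda x: x)
print(flat_windows.cardinality().numpy())  # the question reports -2 here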
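To see what the two steps produce, here is a small sketch on a 10-element range (a stand-in chosen for illustration, not the question's data). It collects each window sub-dataset into a list, then counts the scalars that flat_map(lambda x: x) emits. It passes drop_remainder=True so every window is full; the question's call omits it, in which case the trailing windows come out shorter than window_size.

```python
import tensorflow as tf

# Illustrative stand-in for range_ds: ten elements instead of 100000.
range_ds = tf.data.Dataset.range(10)
window_size = 5

# drop_remainder=True keeps only full windows of length window_size.
windows = range_ds.window(window_size, shift=1, drop_remainder=True)

# Each element of `windows` is itself a Dataset; materialize it to inspect it.
window_lists = [list(sub_ds.as_numpy_iterator()) for sub_ds in windows]
for w in window_lists:
    print(w)

# flat_map(lambda x: x) concatenates the window datasets into one stream of scalars.
flat_windows = windows.flat_map(lambda x: x)
flat_values = list(flat_windows.as_numpy_iterator())
print(len(flat_values))
```

Note that flattening this way loses the window boundaries: the result is just the values of all windows laid end to end.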

The problem is that flat_windows.cardinality().numpy() returns a cardinality of -2, which causes problems for me during training. I looked for a way to set the cardinality of a dataset (something like set_cardinality) but couldn't find anything. I also tried other ways of flattening a dataset of datasets, again without success.
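On the set_cardinality question: a sketch of one option, assuming a TF 2.x release recent enough to ship it, is tf.data.experimental.assert_cardinality, which attaches a cardinality you already know to a dataset (and raises an error during iteration if the actual count differs). The source length n = 100 is hypothetical; with shift=1 and drop_remainder=True the number of full windows is n - window_size + 1. Batching each window inside flat_map follows the tf.data guide's pattern for turning each window into one fixed-length tensor.

```python
import tensorflow as tf

n = 100  # hypothetical length of the source data
window_size = 5

range_ds = tf.data.Dataset.range(n)
windows = range_ds.window(window_size, shift=1, drop_remainder=True)

# Turn each window sub-dataset into a single tensor of shape (window_size,).
flat_windows = windows.flat_map(
    lambda sub_ds: sub_ds.batch(window_size, drop_remainder=True))
print(flat_windows.cardinality().numpy())  # the question reports -2 after flat_map

# With shift=1 and drop_remainder=True there are n - window_size + 1 full windows.
expected = n - window_size + 1
flat_windows = flat_windows.apply(
    tf.data.experimental.assert_cardinality(expected))
print(flat_windows.cardinality().numpy())
```

With the cardinality asserted, Keras should be able to infer the number of steps per epoch from the dataset instead of treating it as unknown.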

Edit-1: The problem during training is that the input shape is unknown (at the Linear and Dense layers) when I train the subclassed model given below. The model trains fine when run eagerly (via tf.config.run_functions_eagerly(True)), but that is slow. Therefore I want the shape of the input data to be known when the model is trained.

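import tensorflow as tf

# Metric shared by train_step/test_step; it is not defined in the question,
# so this Keras-guide-style Mean tracker is an assumption.
loss_tracker = tf.keras.metrics.Mean(name="loss")

class NeuralNetworkModel(tf.keras.Model):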
    def __init__(self):
        super(NeuralNetworkModel, self).__init__()
        self.encoder = Encoder()        
    
    def train_step(self, inputs):       
        X        = inputs[0]
        Y        = inputs[1] 
        
        with tf.GradientTape() as tape:
            enc_X    = self.encoder(X)
            enc_Y    = self.encoder(Y)    

            # loss:        
            loss   = tf.norm(enc_Y - enc_X, axis = [0, 1], ord = 'fro')
                
        # Compute gradients
        trainable_vars = self.encoder.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics
        loss_tracker.update_state(loss)
        
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {"loss": loss_tracker.result()}
        
    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker]
    
    def test_step(self, inputs):       
        X = inputs[0]
        Y = inputs[1] 

        Psi_X    = self.encoder(X)
        Psi_Y    = self.encoder(Y)    

        # loss:        
        loss   = tf.norm(Psi_Y - Psi_X, axis = [0, 1], ord = 'fro')

        # Compute our own metrics
        loss_tracker.update_state(loss)
        
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {"loss": loss_tracker.result()}
        
class Encoder(tf.keras.Model):
    def __init__(self):
        super(Encoder, self).__init__(dtype = 'float64', name = 'Encoder')
        self.input_layer   = DenseLayer(128)
        self.hidden_layer1 = DenseLayer(128)
        self.hidden_layer2 = DenseLayer(64)        
        self.hidden_layer3 = DenseLayer(64)
        self.output_layer  = LinearLayer(64)
        
    def call(self, input_data, training=None):
        fx = self.input_layer(input_data)        
        fx = self.hidden_layer1(fx)
        fx = self.hidden_layer2(fx)
        fx = self.hidden_layer3(fx)
        return self.output_layer(fx)    

class LinearLayer(tf.keras.layers.Layer):
    def __init__(self, units):
        super(LinearLayer, self).__init__(dtype = 'float64')
        self.units = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(shape = (input_dim, self.units), 
                             initializer = "random_normal", 
                             trainable = True)
        self.b = self.add_weight(shape = (self.units,),    
                             initializer = tf.zeros_initializer(),
                             trainable = True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

class DenseLayer(tf.keras.layers.Layer):
    def __init__(self, units):
        super(DenseLayer, self).__init__(dtype = 'float64')
        self.units = units
    
    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(shape = (input_dim, self.units), 
                             initializer = "random_normal", 
                             trainable = True)
        self.b = self.add_weight(shape = (self.units,),    
                             initializer = tf.zeros_initializer(),
                             trainable = True)

    def call(self, inputs):
        x = tf.matmul(inputs, self.w) + self.b
        return tf.nn.elu(x)
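The custom DenseLayer and LinearLayer size their weights from input_shape[-1] in build(), so one thing worth checking (an assumption about the cause, not something the question confirms) is the element_spec of the dataset passed to fit(). The sketch below shows that flat_map(lambda x: x) yields scalar elements, while batching each window with drop_remainder=True inside flat_map gives a static shape of (window_size,). The float64 cast matches the dtype the Encoder layers use, and the batch size of 4 is arbitrary.

```python
import tensorflow as tf

window_size = 5
windows = tf.data.Dataset.range(20).window(window_size, shift=1, drop_remainder=True)

# Flattening straight to scalars: every element has shape ().
scalars = windows.flat_map(lambda x: x)
print(scalars.element_spec)

# Batching inside flat_map with drop_remainder=True keeps a static shape (5,).
fixed = windows.flat_map(lambda x: x.batch(window_size, drop_remainder=True))
print(fixed.element_spec)

# Cast to float64 for the Encoder, then batch for training: static (4, 5).
batched = fixed.map(lambda w: tf.cast(w, tf.float64)).batch(4, drop_remainder=True)
print(batched.element_spec)
```

Without drop_remainder=True on the inner batch, the window dimension would be reported as None, which is exactly the kind of unknown last dimension a build() that reads input_shape[-1] cannot use.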

Recommended answer

I was wondering about this as well. It turns out that -2 is tf.data.UNKNOWN_CARDINALITY (https://www.tensorflow.org/api_docs/python/tf/data#UNKNOWN_CARDINALITY), which means TF does not know how many elements flat_map returns per input element, so it cannot compute the total.
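Because -2 only means "not known in advance", the number of elements can still be found with one pass over the data. The helper below (count_elements is a hypothetical name, not a TF API) returns the reported cardinality when it is known and otherwise counts elements with Dataset.reduce; it refuses infinite datasets, which TF reports as tf.data.INFINITE_CARDINALITY (-1).

```python
import tensorflow as tf

def count_elements(ds: tf.data.Dataset) -> int:
    """Return the size of ds, iterating once if TF reports UNKNOWN_CARDINALITY."""
    card = int(ds.cardinality().numpy())
    if card == tf.data.INFINITE_CARDINALITY:
        raise ValueError("dataset is infinite")
    if card != tf.data.UNKNOWN_CARDINALITY:
        return card
    # Fall back to a full pass that increments a counter per element.
    return int(ds.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1).numpy())

windows = tf.data.Dataset.range(10).window(5, shift=1, drop_remainder=True)
flat_windows = windows.flat_map(lambda x: x)

print(tf.data.UNKNOWN_CARDINALITY)  # → -2
n = count_elements(flat_windows)
print(n)
```

The full pass can be slow on a large pipeline; the resulting count could then be used, for example, as steps_per_epoch in model.fit.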

I just asked "Windowing a TensorFlow dataset without losing cardinality information?" to see if anyone knows a way to window datasets without losing cardinality.
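One alternative, assuming the whole series fits in memory as a single tensor, is to skip Dataset.window and flat_map entirely: tf.signal.frame cuts the tensor into overlapping windows, and Dataset.from_tensor_slices over the result has a known cardinality (here 10 - 5 + 1 = 6) and a static element shape. This is a swapped-in technique, not what the answer above proposes.

```python
import tensorflow as tf

window_size = 5
# Assumption: the full series is available in memory as one 1-D tensor.
series = tf.range(10, dtype=tf.float64)

# Slide a window of length window_size with step 1 along the last axis,
# producing a (num_windows, window_size) tensor.
frames = tf.signal.frame(series, frame_length=window_size, frame_step=1)

# Slicing a tensor whose first dimension is known keeps cardinality known.
windowed_ds = tf.data.Dataset.from_tensor_slices(frames)
print(windowed_ds.cardinality().numpy())
print(windowed_ds.element_spec)
```

The trade-off is memory: the framed tensor stores each value up to window_size times, so this suits moderately sized data better than very long series.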
