Kedro - how to pass nested parameters directly to node


Problem description

Kedro recommends storing parameters in conf/base/parameters.yml. Let's assume it looks like this:

step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000

Now imagine I have a data_engineering pipeline whose nodes.py has a function that looks something like this:

def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass

How would I go about passing that nested parameter straight to this function in data_engineering/pipeline.py? I unsuccessfully tried:

from kedro.pipeline import Pipeline, node

from .nodes import some_pipeline_step


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )

I know that I could just pass all parameters into the function by using ['parameters'], or pass all model_params parameters with ['params:model_params'], but it seems inelegant and I feel like there must be a way. Would appreciate any input!
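For reference, the whole-group workaround I mean looks roughly like this (a sketch only; the outputs are omitted and the unpacking inside the node is just illustrative):

from kedro.pipeline import Pipeline, node


def some_pipeline_step(model_params):
    """Receives the whole `model_params` dict and unpacks what it needs."""
    num_train_steps = model_params["num_train_steps"]


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params"],  # whole group instead of one value
                None,  # outputs omitted for brevity
            )
        ]
    )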

Recommended answer

(Disclaimer: I'm part of the Kedro team)

Thank you for your question. Unfortunately, the current version of Kedro does not support nested parameters. The interim solution would be to use top-level keys inside the node (as you already pointed out) or to decorate your node function with some sort of parameter filter, which is not elegant either.
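As an illustration of the second workaround, such a parameter filter could be sketched like this (the pick_params helper is hypothetical and not part of Kedro):

import functools


def pick_params(*keys):
    """Hypothetical decorator: filter a parameters dict down to the given
    keys before calling the wrapped node function."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(params):
            return func(*(params[key] for key in keys))
        return wrapper
    return decorator


@pick_params("num_train_steps")
def some_pipeline_step(num_train_steps):
    # The node can now be wired as:
    # node(some_pipeline_step, ["params:model_params"], None)
    pass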

Probably the most viable solution would be to customise your ProjectContext class (in src/<package_name>/run.py) by overriding the _get_feed_dict method as follows:

from typing import Any, Dict

from kedro.context import KedroContext


class ProjectContext(KedroContext):
    # ...

    def _get_feed_dict(self) -> Dict[str, Any]:
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
            key = "params:{}".format(param_name)
            feed_dict[key] = param_value

            if isinstance(param_value, dict):
                for key, val in param_value.items():
                    _add_param_to_feed_dict("{}.{}".format(param_name, key), val)

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict
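With that override in place, the nested key from the question can be referenced directly as a node input, for example (outputs omitted for brevity):

from kedro.pipeline import Pipeline, node

from .nodes import some_pipeline_step


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],  # now resolves to 10000
                None,
            )
        ]
    )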

Please also note that this issue has already been addressed on the develop branch and will become available in the next release. The fix uses the approach from the snippet above.
