How to sweep many hyperparameter sets in parallel in Python?

Question

Note that I have to sweep through more argument sets than there are available CPUs, so I'm not sure whether Python will automatically schedule the work across the CPUs as they become available.

Here is what I tried, but I get an error about the arguments:

import random
import multiprocessing
from train_nodes import run
import itertools

envs = ["AntBulletEnv-v0", "HalfCheetahBulletEnv-v0", "HopperBulletEnv-v0", "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0", "InvertedDoublePendulumBulletEnv-v0"]
algs = ["PPO", "A2C"]
seeds = [random.randint(0, 200), random.randint(200, 400), random.randint(400, 600), random.randint(600, 800), random.randint(800, 1000)]

args = list(itertools.product(*[envs, algs, seeds]))

num_cpus = multiprocessing.cpu_count()

with multiprocessing.Pool(num_cpus) as processing_pool:
    processing_pool.map(run, args)

run takes 3 arguments: env, alg, and seed. For some reason, it doesn't receive all 3 here.

Recommended answer

The function passed to multiprocessing.Pool.map is called with a single argument: one item from the iterable. Here each item is an (env, alg, seed) tuple, so run receives one tuple instead of three separate arguments. One way to adapt your code is to write a small wrapper function that takes the (env, alg, seed) tuple as its one argument, unpacks it, and passes the pieces to run.

Another option is to use multiprocessing.Pool.starmap, which unpacks each tuple in the iterable into separate positional arguments for the function.

import random
import multiprocessing
import itertools

envs = [
    "AntBulletEnv-v0",
    "HalfCheetahBulletEnv-v0",
    "HopperBulletEnv-v0",
    "ReacherBulletEnv-v0",
    "Walker2DBulletEnv-v0",
    "InvertedDoublePendulumBulletEnv-v0",
]
algs = ["PPO", "A2C"]
seeds = [
    random.randint(0, 200),
    random.randint(200, 400),
    random.randint(400, 600),
    random.randint(600, 800),
    random.randint(800, 1000),
]

args = list(itertools.product(*[envs, algs, seeds]))

num_cpus = multiprocessing.cpu_count()

# sample implementation of `run`
def run(env, alg, seed):
    # do stuff
    return random.randint(0, 200)

def wrapper(env_alg_seed):
    env, alg, seed = env_alg_seed
    return run(env=env, alg=alg, seed=seed)

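# The guard keeps worker processes from re-running this block when they
# import the main module (needed with the "spawn" start method).
if __name__ == "__main__":
    # Option 1: use the wrapper with map; the results come back as a list
    # in the same order as args
    with multiprocessing.Pool(num_cpus) as processing_pool:
        results = processing_pool.map(wrapper, args)

    # Option 2: use starmap, which unpacks each tuple and calls `run` directly
    with multiprocessing.Pool(num_cpus) as processing_pool:
        results = processing_pool.starmap(run, args)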
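
On the other part of the question, having more argument sets than CPUs: a pool with N workers keeps the remaining tasks queued and hands them out as workers become free, so the list of argument sets can be longer than the worker count. The following is a minimal sketch of that, not the asker's actual setup. The `run` body is a hypothetical stand-in for the real training function, and `sweep` and its `pool_cls` parameter are names made up for illustration. It also pairs each result with its argument set in a dictionary, relying on starmap returning results in input order.

```python
import itertools
import multiprocessing
import multiprocessing.pool  # exposes ThreadPool, an optional thread-backed pool


def run(env, alg, seed):
    # Hypothetical stand-in for the real training function.
    return f"{env}/{alg}/{seed}"


def sweep(arg_sets, workers, pool_cls=multiprocessing.Pool):
    # Hypothetical helper: run every argument set on a pool of `workers`.
    # The pool queues the tasks and gives them to workers as they free up,
    # so len(arg_sets) may be larger than `workers`.
    with pool_cls(workers) as pool:
        results = pool.starmap(run, arg_sets)
    # starmap preserves input order, so zip pairs each result with its arguments.
    return dict(zip(arg_sets, results))


if __name__ == "__main__":
    envs = ["AntBulletEnv-v0", "HopperBulletEnv-v0"]
    algs = ["PPO", "A2C"]
    seeds = [0, 1, 2]
    arg_sets = list(itertools.product(envs, algs, seeds))  # 12 argument sets
    results = sweep(arg_sets, workers=2)  # only 2 worker processes
    for key, value in results.items():
        print(key, "->", value)
```

Passing `pool_cls=multiprocessing.pool.ThreadPool` runs the same sweep on threads, which can be handy for a quick check of the argument plumbing before launching real training processes.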
