Hydra进阶——继承,复用与实例化
摘要
本文介绍 Hydra 的配置文件继承、复用和实例化技巧,包括组内和组间继承、配置组管理、组的拆分以及类对象的实例化和部分实例化。[本摘要由小猫AI生成]
引言
在上一节内容中,我们介绍了Hydra
的基本功能和用法。本节内容中,将介绍Hydra
的进阶使用技巧,包含配置文件的继承、复用和实例化等内容。
附:
继承
Hydra
为配置文件提供了继承的功能。我们可以在一个基配置文件的基础上,对原有配置参数进行修改和新增,从而形成一个新的配置文件。
- 组内继承
首先,创建项目环境,项目结构如下所示:
D:.
│ config.yaml
│ run.py
│ tree.txt
│
├─models
│ model_config_base.yaml
│ model_config_modified.yaml
│ server_config_modified.yaml
│
└─server
server_config.yaml
其中,run.py
代码如下所示:
from omegaconf import OmegaConf, DictConfig
import hydra
@hydra.main(version_base=None, config_path=".", config_name="config")
def run(cfg: DictConfig) -> None:
OmegaConf.resolve(cfg)
print(OmegaConf.to_yaml(cfg))
if __name__ == "__main__":
run()
model_config_base.yaml
为models
组内的基配置文件,它的内容如下:
env:
name: model_config_base
version: 1.0
args:
lr: 0.1
epoch: ???
batchsize: ???
当我们希望在model_config_modified.yaml
文件中,继承基配置文件的参数,并在此基础上进行参数修改和新增,可以使用defaults
关键字,指定基配置文件的路径名。在这个例子中,我们分别对基配置文件中的env.name
进行修改,对args.epoch
和args.batchsize
进行赋值,并新增env.charset
参数,model_config_modified.yaml
内容如下所示:
defaults:
- model_config_base
env:
name: model_config_modified
charset: utf-8
args:
epoch: 20
batchsize: 64
在config.yaml
中指定需要运行的默认配置文件:
defaults:
- models: model_config_modified
运行run.py
,可以得到以下结果:
models:
env:
name: model_config_modified
version: 1.0
charset: utf-8
args:
lr: 0.1
epoch: 20
batchsize: 64
可以看到,这样的运行结果不仅继承了基配置文件的各项参数,还对子配置文件所修改、赋值和新增的各项参数执行了对应的操作。通过这种方式,我们不仅可以更加灵活地复用配置文件并指定每次运行的不同参数,还可以防止基配置文件的参数被覆盖而丢失。
- 组间继承
假设我们需要在一个组下继承另一个组中的配置文件,需要对defaults
关键词的内容进行修改,格式为组名/文件名@_here_
。
在这个例子中,我们希望在models
文件夹下继承server
组中的server_config.yaml
文件,并生成新的子配置文件server_config_modified.yaml
。原server_config.yaml
的内容为:
server:
port: 80
host: localhost
username: ye
此时,我们对其中的username
参数进行修改,则server_config_modified.yaml
文件内容应为:
defaults:
- /server/server_config@_here_
username: hello
修改此时的config.yaml
为:
defaults:
- models: server_config_modified
运行run.py
,得到以下结果:
models:
server:
port: 80
host: localhost
username: hello
这样,我们就完成了组间配置文件的继承。
如果在子配置文件的defaults
关键词中不添加@_here_
,则修改参数将会被独立在目标组外。修改server_config_modified.yaml
为:
defaults:
- /server/server_config
username: hello
再次运行run.py
,得到如下结果:
models:
server:
server:
port: 80
host: localhost
username: ye
username: hello
配置组
在config.yaml
文件中,我们设定了每个组下的默认配置文件。然而,在实际的实验过程中,我们可能需要针对不同的环境和条件使用不同的实验组合,对于每个不同的配置组,逐一修改config.yaml
中的默认文件是非常耗时和难以维护的。因此,我们可以为这些不同配置文件的组合设定单独的配置组,从而便于我们的使用和管理。
新建文件夹experiment
,并创建两个不同环境下的配置组文件common.yaml
和run_on_server.yaml
,默认的config.yaml
文件内容如下:
defaults:
- models: server_config_modified
- server: server_config
直接运行run.py
,可以得到以下结果:
models:
server:
port: 80
host: localhost
username: hello
server:
server:
port: 80
host: localhost
username: ye
为了设定一个配置组,我们需要在YAML
文件中使用# @package _global_
注解,同时对需要修改的默认配置文件路径进行修改。将common.yaml
修改为:
# @package _global_
defaults:
- override /models: model_config_base
使用命令行运行python run.py +experiment=common
,结果如下:
models:
env:
name: model_config_base
version: 1.0
args:
lr: 0.1
epoch: ???
batchsize: ???
server:
server:
port: 80
host: localhost
username: ye
可以看到,通过这种方式,我们成功修改了models
组内的运行文件,而server
组内的默认运行文件保持不变。此外,我们还可以对参数进行修改,将run_on_server.yaml
文件修改如下:
# @package _global_
defaults:
- override /server: server_config
- override /models: model_config_modified
server:
username: experiment
执行命令行python run.py +experiment=run_on_server
,运行结果如下:
models:
env:
name: model_config_modified
version: 1.0
charset: utf-8
args:
lr: 0.1
epoch: 20
batchsize: 64
server:
server:
port: 80
host: localhost
username: experiment
可以看到,通过这种方式,我们同时指定了两组需要修改的默认配置文件,同时将server
组中的username
参数进行了修改。
组的拆分
在同一个组内,无法同时将两个或多个配置文件设定为默认运行文件。如我们将config.yaml
修改为如下形式,则程序会引发异常:
defaults:
- models: model_config_modified
- models: server_config_modified
运行结果:
Multiple values for models. To override a value use 'override models: server_config_modified'
因此,当一个组内同时出现了两个或多个需要运行的配置文件时,我们需要将其拆分到两个不同的新组,从而实现这个操作。为了实现这个过程,我们需要在config.yaml
文件中使用原组名@新组名: 文件路径
的方式来进行组的拆分。在这个例子中,我们作出如下修改:
defaults:
- models@args: model_config_modified
- models@remote: server_config_modified
运行结果如下:
args:
env:
name: model_config_modified
version: 1.0
charset: utf-8
args:
lr: 0.1
epoch: 20
batchsize: 64
remote:
server:
port: 80
host: localhost
username: hello
配置实例化
- 实例化
如果我们希望将配置参数为代码中的类对象实例化,可以在配置文件中使用_target_
关键字指定类名,将成员变量以参数的形式在配置文件中指定,同时配合python
函数instatiate
对类进行实例化。下面,我们将给出一个例子。
在使用实例化之前,我们需要在python
代码中引入对应的库函数。
from hydra.utils import instantiate
在run.py
文件中,我们创建一个Model
类,并希望通过配置文件config.yaml
对该类进行实例化。run.py
代码如下:
from omegaconf import OmegaConf, DictConfig
import hydra
from hydra.utils import instantiate
class Model:
def __init__(self, epoch: int, batchsize: int, lr: float) -> None:
self.epoch = epoch
self.batchsize = batchsize
self.lr = lr
def print_args(self):
print(f"epoch = {self.epoch}, batchsize = {self.batchsize}, lr = {self.lr}")
@hydra.main(version_base=None, config_path=".", config_name="config")
def run(cfg: DictConfig) -> None:
model = instantiate(cfg)
model.print_args()
if __name__ == "__main__":
run()
在config.yaml
文件中,我们需要将_target_
参数设定为run.Model
,文件内容如下:
_target_: run.Model
lr: 0.1
epoch: 20
batchsize: 32
运行run.py
,结果如下:
epoch = 20, batchsize = 32, lr = 0.1
- 部分实例化
如果我们需要对某个类对象的部分参数进行实例化,可以在config.yaml
文件中将_partial_
关键词置为true
。在下面一个实例中,我们将对Model
的初始化传入Optimizer
类对象,同时只对该类对象的lr
参数进行配置文件的赋值。run.py
的代码如下所示:
from omegaconf import OmegaConf, DictConfig
import hydra
from hydra.utils import instantiate
class Optimizer:
def __init__(self, lr: float, name: str) -> None:
self.lr = lr
self.name = name
class Model:
def __init__(self, epoch: int, batchsize: int, opt: Optimizer) -> None:
self.epoch = epoch
self.batchsize = batchsize
self.opt = opt(name = "default")
def print_args(self):
print(f"epoch = {self.epoch}, batchsize = {self.batchsize}, optimizer_name = {self.opt.name}, lr = {self.opt.lr}")
@hydra.main(version_base=None, config_path=".", config_name="config")
def run(cfg: DictConfig) -> None:
model = instantiate(cfg)
model.print_args()
if __name__ == "__main__":
run()
此时,应将config.yaml
修改为:
_target_: run.Model
epoch: 20
batchsize: 32
opt:
_partial_: true
_target_: run.Optimizer
lr: 0.1
运行结果:
epoch = 20, batchsize = 32, optimizer_name = default, lr = 0.1