【多智能体强化学习】Pymarl代码分析
Last updated on February 22, 2024 am
Pymarl代码结构
本文章主要介绍多智能体强化学习中的PyMarl框架的代码结构以及训练流程
Main
Pymarl的主文件(main.py
)主要的作用是构建一个
sacred.Experiment 类的对象 ex
,ex
包含三个重要的内置变量:
_run
:表示当前实验运行时的 run 对象,_run.info
可用于记录实验中产生的结果,实验初始时是空字典{}
;_config
:表示当前实验运行时的参数,字典类,pymarl首先读取配置文件然后利用ex.add_config()
将配置文件中的参数添加到_config
变量中;_log
:一个 logger,pymarl首先创建了一个 logging.logger 类的对象logger
,然后将logger
赋给了ex.logger
也就是_log
。_log
可以通过_log.info('information')
在控制台打印实验过程中的中间信息,方便我们能够定期追踪实验状态。
1 |
|
Run
如何进入:通过借助装饰器定义一个主函数,其中还定义了如何运行
1 |
|
关于这部分run_REGISTRY
字典的设置,在初始化中文件中会进行指定:
1 |
|
这个地方引入的
import run as default_run
就是指定了在某个文件中的run
函数主体,通过最外层的装饰器进行指定运行那个合适的run
函数
run_REGISTRY[_config['run']](_run, config, _log)
表示了指定运行哪个函数,并传递相关的参数
run.py
文件中run
函数的主要作用是构建实验参数变量 args
以及一个自定义 Logger
类的记录器
logger
。内置变量_config
的拷贝作为参数传入到了run
函数中,_config
是字典变量,因此查看参数时,需要利用 _config[key]=value
,在
run
函数中,作者构建了一个namespace
类的变量args
,将_config
中的参数都传给了
args
,这样就可以通过args.key=value
的方式查看参数了。
其中'run'
指的是在参数文件中指定的,比如在
default.yaml
文件中会进行指定字典中的元素
1 |
|
上述代码中创建了一个 Logger
类的实例,用于封装日志记录的功能,用于记录信息,_log.info()
使用的就是用于打印日志信息的语句,其中消息内容哥就是
Experiment Parameters
1 |
|
这一步是主要的实验运行的板块
==run_sequential
==,
借助这个板块对实验运行的内容进行管理和控制,接下来将对这个板块的实现进行详细的介绍。将参数以及初始化创建之后的日志板块都传递进这个函数内部进行训练
1 |
|
结尾部分就是对进程的退出和控制。
Run Sequential
这部分是run的核心,也是对内部的函数进行控制和调用的关键
Mac管理器
主要被扔进 runnner
和 learner
两个板块中使用
属于自定义的controller.basic_controller.BasicMAC
类,该对象的主要作用是
控制智能体,因此mac
对象中的一个重要属性就是nn.module
类的智能体对象mac.agent
,该对象定义了各个智能体的局部Q网络,即接收观测作为输入,输出智能体各个动作的Q值。
1 |
|
mac
对象有两个关键方法:
mac.forward(ep_batch, t, test_mode=False)
:ep_batch
表示一个episode的样本,t
表示每个样本在该episode内的时间索引,forward()
方法的作用是输出一个episode内每个时刻的观测对应的所有动作的Q值与隐层变量mac.hidden_states
。mac.select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False)
:该方法用于在一个episode中每个时刻为所有智能体选择动作。t_ep
代表当前样本在一个episode中的时间索引。t_env
代表当前时刻环境运行的总时间,用于计算epsilon-greedy中的epsilon。
环境运行器
首先初始化环境的运行器,这部分能够获取环境的信息
1 |
|
尝试从环境中获取相关的数据信息,如各种智能体数量、动作空间维度、状态的维度等
其中getattr(args, "accumulated_episodes", None)
的意思是,如果
args
中有"accumulated_episodes"
,那么就获取,否则就设置为
None
一种很优雅的写法,下面的也是同理
1 |
|
这个地方就是定义需要用到的强化学习字典数据结构
1 |
|
初始化运行环境,确保所有必要的组件都已经配置好,以便开始运行或者训练算法。这个方法可能会创建必要的数据结构,设置好数据流,以及准备环境来执行特定的训练或测试循环
1 |
|
runner
对象中最关键的方法是:
runner.run(test_mode=False)
:利用当前智能体mac
在环境中运行(需要用到mac
对象),产生一个episode的样本数据episode_batch
,存储在runner.batch
中。
Buffer存储器
buffer
对象属于自定义的components.episode_buffer.ReplayBuffer(EpisodeBatch)
类,该对象的主要作用是存储样本以及采样样本。ReplayBuffer
的父类是EpisodeBatch
。EpisodeBatch
类对象用于存储episode的样本,ReplayBuffer(EpisodeBatch)
类对象则用于存储所有的off-policy样本,也即EpisodeBatch
类变量的样本会持续地补充到ReplayBuffer(EpisodeBatch)
类的变量中。同样由于QMix用的是DRQN结构,因此EpisodeBatch
与ReplayBuffer
中的样本都是以episode为单位存储的。在EpisodeBatch
中数据的维度是[batch_size, max_seq_length, *shape]
,ReplayBuffer
类数据的维度是[buffer_size, max_seq_length, *shape]
。EpisodeBatch
中Batch Size
表示此时batch中有多少episode,ReplayBuffer
中episodes_in_buffer
表示此时buffer中有多少个episode的有效样本。max_seq_length
则表示一个episode的最大长度。
buffer
对象中的关键方法有:
buffer.insert_episode_batch(ep_batch)
:将EpisodeBatch
类变量ep_batch
中的样本全部存储到buffer
中。buffer.sample(batch_size)
:从buffer
中取出batch_size
个episode的样本用于训练,这些样本组成了EpisodeBatch
类的对象
Learner学习器
自定义的leaners.q_learner.QLearner
(与具体选择哪个算法有关),该对象的主要作用是依据特定算法对智能体参数进行训练更新。在QMix算法中,有nn.module
类的混合网络learner.mixer
,因此learner
对象需要学习的参数包括各个智能体的局部Q网络参数mac.parameters()
1 |
|
其中这个 [args.learner]
参数是根据调用命令的时候传递的config
的参数文件config.yml
中的参数,在内部会一并将其参数统一收集使用
learner.train(batch: EpisodeBatch, t_env: int, episode_num: int)
:batch
表示当前用于训练的样本,t_env
表示当前环境运行的总时间步数,episode_num
表示当前环境运行的总episode数,该方法利用特定算法对learner.params
进行更新
Train过程
收集buffer
1 |
|
with th.no_grad():
在前期的操作中我们只需要收集环境的交互信息,因此此时不需要用到梯度和推理
episode_batch = runner.run(test_mode=False)
这行代码调用了一个名为runner
的对象的方法。这个runner
可能是一个负责执行环境交互和收集经验的模块。test_mode=False
表明这个调用是在训练模式下运行的,而不是测试模式在训练过程中,智能体(agent)会根据策略(policy)在环境中采取行动,收集一系列的剧集(episodes),这些剧集包含了状态(states)、动作(actions)、奖励(rewards)等信息.
buffer.insert_episode_batch(episode_batch)
这行代码将刚刚收集到的剧集批次(episode_batch)插入到一个名为buffer的数据结构中。
1 |
|
这一步是用来判断当前的buffer
的容量是否超过了满足选择batch_size
的大小,如果满足的话那么就可以在
buffer
中开始选择对应的大小
1 |
|
上述的过程是对buffer
中的元素内容进行采样
1 |
|
这一步是在补充时间步骤,在强化学习中,每个episode可能会有不同长度,因为它们可能由于早停或达到某个终止条件而长度不一。
第二步是在从其中sample
之后的数据元素进行截取选择的结果
Train
具体的学习过程就会扔进learner
的学习器中进行训练参数
1 |
|
调用学习器的train
方法来实际进行训练。episode_sample
是用来训练的数据,runner.t_env
是环境步数(可能用于记录或者计算折扣因子等),episode
可能是一个计数器或者记录当前是第几个episode
的变量
注意对于价值函数分解的方法会额外用到mixer
为了具体分析learner
的训练过程,我们这里给出基于
group
的学习代码来进行分析
首先第一步是进行初始化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24def __init__(self, mac, scheme, logger, args):
self.args = args
self.mac = mac
self.target_mac = copy.deepcopy(mac)
self.params = list(self.mac.parameters())
self.logger = logger
self.device = th.device('cuda' if args.use_cuda else 'cpu')
if args.mixer == "group":
self.mixer = GroupMixer(args)
else:
raise "mixer error"
self.target_mixer = copy.deepcopy(self.mixer)
self.params += list(self.mixer.parameters())
#其中这一步是将模型的参数量级进行输出
print('Mixer Size: ')
print(get_parameters_num(self.mixer.parameters()))
if self.args.optimizer == 'adam':
self.optimiser = Adam(params=self.params, lr=args.lr)
else:
self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)
self.last_target_update_episode = 0
self.log_stats_t = -self.args.learner_log_interval - 1
self.train_t = 0这里会初始化定义一些比如
mac/target mac
和mixer/target mixer
的初始化参数内容,同时定义优化算子:Adam/RMSprop等内容估计所有智能体的Q价值函数值
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16mac_out = []
mac_hidden = []
mac_group_state = []
self.mac.init_hidden(batch.batch_size)
for t in range(batch.max_seq_length):
agent_outs = self.mac.forward(batch, t=t)
mac_hidden.append(self.mac.hidden_states)
mac_group_state.append(self.mac.group_states)
mac_out.append(agent_outs)
mac_out = th.stack(mac_out, dim=1) #得到所有的智能体的输出并堆在一起
mac_hidden = th.stack(mac_hidden, dim=1)
mac_group_state = th.stack(mac_group_state, dim=1)
mac_hidden = mac_hidden.detach() #这一步相当于从计算图中取消不计算梯度学习并得到联合的Q价值函数
1
2
3
4
5# Pick the Q-Values for the actions taken by each agent
chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)
# Mixer
chosen_action_qvals, w1_avg_list, sd_loss = self.mixer(chosen_action_qvals, batch["state"][:, :-1], mac_hidden[:, :-1], mac_group_state[:, :-1], "eval")这一步中调用
mixer
的函数来学习联合的动作价值函数计算target Q的值
最开始的步骤和Q函数的处理过程一样来学习每个智能体的价值函数
1
2
3
4
5
6
7
8
9
10
11with th.no_grad():
target_mac_out = []
target_mac_hidden = []
target_mac_group_state = []
self.target_mac.init_hidden(batch.batch_size)
for t in range(batch.max_seq_length):
target_agent_outs = self.target_mac.forward(batch, t=t)
target_mac_hidden.append(self.target_mac.hidden_states)
target_mac_group_state.append(self.target_mac.group_states)
target_mac_out.append(target_agent_outs)计算得到目标的联合动作价值函数
1
2# Calculate n-step Q-Learning targets
target_max_qvals, _, _ = self.target_mixer(target_max_qvals, batch["state"], target_mac_hidden, target_mac_group_state, "target")计算TD target
1
2targets = build_td_lambda_targets(rewards, terminated, mask, target_max_qvals,
self.args.n_agents, self.args.gamma, self.args.td_lambda)相当于计算 \(r+\gamma Q_{taregt}\)
计算TD error
1
2td_error = (chosen_action_qvals - targets.detach())
td_error = 0.5 * td_error.pow(2)注意填充的部分
mask = mask.expand_as(td_error)
这一行代码中,mask
它的元素为True
的地方表示对应的序列元素是有效的,而False
的地方表示对应的序列元素是填充的(或不存在的,在序列的开始部分)expand_as
方法会将mask
张量扩展到与td_error
相同的大小,这样做是为了保证在计算masked_td_error
时,mask
能够覆盖td_error
的每个元素。masked_td_error = td_error * mask
在这一行中,结果masked_td_error
只有在mask为True
的位置才会有非零值,而在mask
为False
的位置(即填充的部分)将会是零。td_loss = masked_td_error.sum() / mask.sum()
最后一行代码计算了masked_td_error
的总和,然后除以mask
的总和,这样可以得到一个标量损失值,这个值只包含了有效序列元素的贡献。这样做是为了在计算损失时忽略填充的部分,因为这些部分在训练过程中不应该影响模型的学习.
1
2
3mask = mask.expand_as(td_error)
masked_td_error = td_error * mask
td_loss = masked_td_error.sum() / mask.sum()最后的优化算子的部分反向传播
1
2
3
4self.optimiser.zero_grad()
loss.backward()
grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
self.optimiser.step()按照一定的时间步长更新目标网络
1
2
3if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
self._update_targets()
self.last_target_update_episode = episode_num根据需要可以选择打印出log
1
2
3
4
5
6
7
8
9
10
11if t_env - self.log_stats_t >= self.args.learner_log_interval:
self.logger.log_stat("loss_td", td_loss.item(), t_env)
self.logger.log_stat("grad_norm", grad_norm, t_env)
mask_elems = mask.sum().item()
self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env)
self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
self.logger.log_stat("total_loss", loss.item(), t_env)
self.logger.log_stat("lasso_loss", lasso_loss.item(), t_env)
self.logger.log_stat("sd_loss", sd_loss.item(), t_env)
self.log_stats_t = t_env
Test
1 |
|
注意这个地方就是对算法进行测试了