Twin Delayed DDPG (TD3)

Twin Delayed DDPG (TD3)

一月 10, 2021

虽然DDPG有时可以获得很好的性能,但对于超参数和其他类型的调优,它经常是脆弱的。DDPG最常见的失效模式是学习后的q函数开始大幅高估q值,从而导致策略失效,因为它利用了q函数中的错误。Twin Delayed DDPG (TD3)是一种通过引入三个关键技巧来解决这个问题的算法:

  • Clipped Double Q-learning: TD3学习两个q函数而不是一个(因此称为“Twin”),并使用两个q值中较小的一个来形成Bellman误差损失函数中的目标。
  • “Delayed Policy Update: TD3更新策略(和目标网络)的频率低于q函数。比如,每两个q函数更新进行一次策略更新。
  • Target Policy Smoothing: TD3向目标动作添加噪声,使策略更难利用Q函数错误,方法是使Q沿着动作的变化平滑。

Key Equations

Target Policy Smoothing

用于形成q学习目标的动作是基于目标策略,$\mu{\theta{\text{targ}}} $,但是在动作的每个维度上都添加了剪切噪声。在添加了被剪辑的噪声之后,目标动作就会被剪辑到有效的动作范围内(所有有效的动作$a$,满足$a{Low} \leq a \leq a{High}$)。目标操作如下:

目标策略平滑实质上是算法的正则化。它解决了DDPG中可能发生的特定失效模式:如果q函数近似器为某些动作开发了一个不正确的尖峰,策略将迅速利用该尖峰,然后产生脆弱或不正确的行为。这可以通过平滑类似行为的q函数来避免,这是政策平滑的目标。

Clipped Double Q-learning

两个q函数都使用一个目标,使用两个q函数中的任意一个计算出一个较小的目标值:

然后他们都通过回归这个目标来学习:

为目标使用较小的q值,并向其回归,有助于避免q函数中的过高估计。

最后:通过最大化$Q_{\phi_1}$来学习策略:

这和DDPG几乎没有什么区别。然而,在TD3中,策略的更新频率低于q函数。这有助于抑制DDPG中由于策略更新更改目标的方式而出现的波动性。

算法

Algorithm

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.f1 = nn.Linear(state_dim, 256)
self.f2 = nn.Linear(256, 128)
self.f3 = nn.Linear(128, action_dim)
self.max_action = max_action
def forward(self,x):
x = self.f1(x)
x = F.relu(x)
x = self.f2(x)
x = F.relu(x)
x = self.f3(x)
return torch.tanh(x) * self.max_action
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic,self).__init__()
self.f11 = nn.Linear(state_dim+action_dim, 256)
self.f12 = nn.Linear(256, 128)
self.f13 = nn.Linear(128, 1)

self.f21 = nn.Linear(state_dim + action_dim, 256)
self.f22 = nn.Linear(256, 128)
self.f23 = nn.Linear(128, 1)

def forward(self, state, action):
sa = torch.cat([state, action], 1)

x = self.f11(sa)
x = F.relu(x)
x = self.f12(x)
x = F.relu(x)
Q1 = self.f13(x)

x = self.f21(sa)
x = F.relu(x)
x = self.f22(x)
x = F.relu(x)
Q2 = self.f23(x)

return Q1, Q2


self.actor = Actor(self.state_dim, self.action_dim, self.max_action)
self.target_actor = copy.deepcopy(self.actor)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)


self.critic = Critic(self.state_dim, self.action_dim)
self.target_critic = copy.deepcopy(self.critic)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def learn(self):
self.total_it += 1
data = self.buffer.smaple(size=128)
state, action, done, state_next, reward = data
with torch.no_grad:
noise = (torch.rand_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
next_action = (self.target_actor(state_next) + noise).clamp(-self.max_action, self.max_action)
target_Q1,target_Q2 = self.target_critic(state_next, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + done * self.discount * target_Q
current_Q1, current_Q2 = self.critic(state, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
critic_loss.backward()
self.critic_optimizer.step()

if self.total_it % self.policy_freq == 0:

q1,q2 = self.critic(state, self.actor(state))
actor_loss = -torch.min(q1, q2).mean()

self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

for param, target_param in zip(self.actor.parameters(), self.target_actor.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)