pytorch深度Q网络

news/2025/2/1 22:41:19 标签: pytorch, 人工智能, python

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

DQN 引入了深度神经网络来近似Q函数,解决了传统Q-learning在处理高维状态空间时的瓶颈,尤其是在像 Atari 游戏这样的复杂环境中。DQN的核心思想是使用神经网络 Q(s,a;θ)Q(s, a; \theta)Q(s,a;θ) 来近似 Q 值函数,其中 θ\thetaθ 是神经网络的参数。

DQN 的关键创新包括:

  1. 经验回放(Experience Replay):在强化学习中,当前的学习可能会依赖于最近的经验,容易导致学习过程的不稳定。经验回放通过将智能体的经历存储到一个回放池中,然后随机抽取批量数据进行训练,这样可以打破数据之间的相关性,使得训练更加稳定。

  2. 目标网络(Target Network):在Q-learning中,Q值的更新依赖于下一个状态的最大Q值。为了避免Q值更新时过度依赖当前网络的输出(导致不稳定),DQN引入了目标网络。目标网络的结构与行为网络相同,但它的参数更新频率较低,这使得Q值更新更加稳定。

DQN算法流程

  1. 初始化Q网络:初始化Q网络的参数 θ\thetaθ,以及目标网络的参数 θ−\theta^-θ−(通常与Q网络相同)。
  2. 行为选择:基于当前的Q网络来选择动作(通常使用ε-greedy策略,即以ε的概率选择随机动作,否则选择当前Q值最大的动作)。
  3. 执行动作并存储经验:执行所选动作,观察奖励,并记录状态转移 (st,at,rt+1,st+1)(s_t, a_t, r_{t+1}, s_{t+1})(st​,at​,rt+1​,st+1​)。
  4. 经验回放:从回放池中随机抽取一个小批量的经验数据。
  5. 计算Q值目标:对于每个样本,计算目标值 y=rt+1+γmax⁡a′Q(st+1,a′;θ−)y = r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a'; \theta^-)y=rt+1​+γmaxa′​Q(st+1​,a′;θ−)。
  6. 更新Q网络:通过最小化损失函数 L(θ)=1N∑(y−Q(st,at;θ))2L(\theta) = \frac{1}{N} \sum (y - Q(s_t, a_t; \theta))^2L(θ)=N1​∑(y−Q(st​,at​;θ))2 来更新Q网络的参数。
  7. 周期性更新目标网络:每隔一段时间,将Q网络的参数复制到目标网络。

DQN的应用

DQN在多个领域取得了重要应用,尤其是在强化学习任务中:

  • Atari 游戏:DQN 在多个经典的 Atari 游戏上成功展示了其能力,比如《Breakout》和《Pong》等。
  • 机器人控制:利用DQN,机器人可以在复杂的环境中自主学习如何执行任务。
  • 自动驾驶:在自动驾驶领域,DQN可以用来训练智能体通过道路、避开障碍物等。

例子:

这里我们手动实现一个非常简单的环境:一个1D平衡问题,类似于一个可以左右移动的棒球,目标是让它保持在某个位置上。

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt


# 自定义环境
class SimpleEnv:
    def __init__(self):
        self.state = 0.0  # 初始状态
        self.goal = 10.0  # 目标位置
        self.done = False

    def reset(self):
        self.state = 0.0
        self.done = False
        return self.state

    def step(self, action):
        if self.done:
            return self.state, 0, self.done  # 游戏结束,不再变化

        # 通过动作修改状态
        self.state += action  # 动作是 -1、0、1,控制移动方向
        reward = -abs(self.state - self.goal)  # 奖励是距离目标位置的负值

        # 如果距离目标很近,就结束
        if abs(self.state - self.goal) < 0.1:
            self.done = True
            reward = 10  # 达到目标时奖励较高

        return self.state, reward, self.done


# Q网络定义
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc = nn.Linear(input_dim, 24)
        self.fc2 = nn.Linear(24, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc(x))
        x = self.fc2(x)
        return x


# DQN智能体
class DQN:
    def __init__(self, env, gamma=0.99, epsilon=0.1, batch_size=32, learning_rate=1e-3):
        self.env = env
        self.gamma = gamma
        self.epsilon = epsilon
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        self.input_dim = 1  # 因为环境状态是一个单一的数值
        self.output_dim = 3  # 动作空间大小:-1, 0, 1

        self.q_network = QNetwork(self.input_dim, self.output_dim)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.choice([-1, 0, 1])  # 随机选择动作
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_network(state)
        # 将动作值 -1, 0, 1 转换为索引 0, 1, 2
        action_idx = torch.argmax(q_values, dim=1).item()
        action_map = [-1, 0, 1]  # -1 -> 0, 0 -> 1, 1 -> 2
        return action_map[action_idx]

    def update(self, state, action, reward, next_state, done):
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
        # 将动作 -1, 0, 1 转换为索引 0, 1, 2
        action_map = [-1, 0, 1]
        action_idx = action_map.index(action)
        action = torch.tensor(action_idx, dtype=torch.long).unsqueeze(0)
        reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(0)

        # 确保done是Python标准bool类型
        done = torch.tensor(done, dtype=torch.float32).unsqueeze(0)

        # 计算目标Q值
        with torch.no_grad():
            next_q_values = self.q_network(next_state)
            next_q_value = next_q_values.max(1)[0]
            target_q_value = reward + self.gamma * next_q_value * (1 - done)

        # 获取当前Q值
        q_values = self.q_network(state)
        action_q_values = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

        # 计算损失并更新Q网络
        loss = self.criterion(action_q_values, target_q_value)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, num_episodes=200):
        rewards = []
        best_reward = -float('inf')  # 初始最好的奖励设为负无穷
        best_episode = 0

        for episode in range(num_episodes):
            state = self.env.reset()  # 获取初始状态
            total_reward = 0
            done = False
            while not done:
                action = self.select_action([state])
                next_state, reward, done = self.env.step(action)
                total_reward += reward

                # 更新Q网络
                self.update([state], action, reward, [next_state], done)

                state = next_state

            rewards.append(total_reward)
            # 记录最佳奖励和对应的episode
            if total_reward > best_reward:
                best_reward = total_reward
                best_episode = episode

            print(f"Episode {episode}, Total Reward: {total_reward}")

        # 打印最佳结果
        print(f"Best Reward: {best_reward} at Episode {best_episode}")

        # 绘制奖励图
        plt.plot(rewards)
        plt.title('Total Rewards per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')

        # 在最佳位置添加标记
        plt.scatter(best_episode, best_reward, color='red', label=f"Best Reward at Episode {best_episode}")
        plt.legend()
        plt.show()


# 初始化环境和DQN智能体
env = SimpleEnv()
dqn = DQN(env)

# 训练智能体
dqn.train()


http://www.niftyadmin.cn/n/5839617.html

相关文章

2025数学建模美赛|赛题翻译|E题

2025数学建模美赛&#xff0c;E题赛题翻译 更多美赛内容持续更新中...

双写+灰度发布:高并发场景下的维度表拆分零事故迁移实践

目录 0 文章摘要 1业务场景描述 2 迁移及实施过程 2.1 拆分设计与数据探查 2.1 历史数据迁移(全量)

仿真设计|基于51单片机的温度与烟雾报警系统

目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现&#xff08;protues8.7&#xff09; 程序&#xff08;Keil5&#xff09; 全部内容 资料获取 具体实现功能 &#xff08;1&#xff09;LCD1602实时监测及显示温度值和烟雾浓度值&#xff1b; &#xff08;2…

langgraph实现 handsoff between agents 模式 (1)

官网示例代码 from typing_extensions import Literal from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langgraph.graph import MessagesState, StateGraph, START from langgraph.types import Command from langchain_openai…

【ubuntu】双系统ubuntu下一键切换到Windows

ubuntu下一键切换到Windows 1.4.1 重启脚本1.4.2 快捷方式1.4.3 移动快捷方式到系统目录 按前文所述文档&#xff0c;开机默认启动ubuntu。Windows切换到Ubuntu直接重启就行了&#xff0c;而Ubuntu切换到Windows稍微有点麻烦。可编辑切换重启到Windows的快捷方式。 1.4.1 重启…

吉首市城区地图政府附近1公里范围高清矢量pdf\cdr\ai内容测评

吉首市城区地图以市政府中心附近1公里范围高清矢量pdf\cdr\ai(2021年详细&#xff09;&#xff0c;可以用cdr&#xff0c;ai软件打开编辑文字内容&#xff0c;放大。

WordPress eventon-lite插件存在未授权信息泄露漏洞(CVE-2024-0235)

免责声明: 本文旨在提供有关特定漏洞的深入信息,帮助用户充分了解潜在的安全风险。发布此信息的目的在于提升网络安全意识和推动技术进步,未经授权访问系统、网络或应用程序,可能会导致法律责任或严重后果。因此,作者不对读者基于本文内容所采取的任何行为承担责任。读者在…

CAP 定理的 P 是什么

分布式系统 CAP 定理 P 代表什么含义 作者之前在看 CAP 定理时抱有很大的疑惑&#xff0c;CAP 定理的定义是指在分布式系统中三者只能满足其二&#xff0c;也就是存在分布式 CA 系统的。作者在网络上查阅了很多关于 CAP 文章&#xff0c;虽然这些文章对于 P 的解释五花八门&am…