扩展ReinforcementLearner
This commit is contained in:
parent
2d96642da8
commit
2d65d39de2
@@ -9,21 +9,19 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class MyCoolRLModel(ReinforcementLearner):
|
||||
"""
|
||||
针对高波动资产 (如 PEPE) 优化的强化学习模型。
|
||||
核心改进:
|
||||
1. 极度简化奖励函数,让模型容易学习
|
||||
2. 大幅鼓励入场和交易
|
||||
3. 减少对亏损的惩罚,让模型敢于尝试
|
||||
基于官方 ReinforcementLearner 的自定义强化学习模型
|
||||
保留了官方的奖励函数框架,同时增加了我们特有的功能
|
||||
"""
|
||||
|
||||
class MyRLEnv(Base5ActionRLEnv):
|
||||
"""
|
||||
自定义环境,重写 calculate_reward 以适应动量交易。
|
||||
自定义环境,重写 calculate_reward 以适应我们的需求
|
||||
继承官方框架,但增强了一些特定功能
|
||||
"""
|
||||
|
||||
def calculate_reward(self, action: int) -> float:
|
||||
"""
|
||||
优化的奖励函数,特别处理亏损订单持仓太久的问题。
|
||||
优化的奖励函数,结合官方框架和我们特定需求
|
||||
|
||||
:param action: int = 智能体为当前K线做出的动作。
|
||||
:return:
|
||||
@@ -32,7 +30,7 @@ class MyCoolRLModel(ReinforcementLearner):
|
||||
# 首先,如果动作无效,则惩罚
|
||||
if not self._is_valid(action):
|
||||
self.tensorboard_log("invalid", category="actions")
|
||||
return -1.0
|
||||
return -2
|
||||
|
||||
# 获取核心状态数据
|
||||
pnl = self.get_unrealized_profit()
|
||||
@@ -42,29 +40,41 @@ class MyCoolRLModel(ReinforcementLearner):
|
||||
if self._last_trade_tick is not None:
|
||||
trade_duration = self._current_tick - self._last_trade_tick
|
||||
|
||||
# 获取配置参数
|
||||
max_trade_duration = self.rl_config.get("max_trade_duration_candles", 300)
|
||||
profit_aim = self.rl_config["model_reward_parameters"].get("profit_aim", 0.025)
|
||||
rr = self.rl_config["model_reward_parameters"].get("rr", 1)
|
||||
|
||||
# =========================================================
|
||||
# 场景 1: 入场 (Long Enter)
|
||||
# =========================================================
|
||||
if action == Actions.Long_enter.value and self._position == Positions.Neutral:
|
||||
# 奖励入场,鼓励模型尝试
|
||||
return 0.5
|
||||
# 使用官方的高奖励鼓励入场
|
||||
return 25
|
||||
|
||||
# =========================================================
|
||||
# 场景 2: 空仓观望 (Neutral) - 重罚!
|
||||
# 场景 2: 空仓观望 (Neutral) - 轻罚
|
||||
# =========================================================
|
||||
if action == Actions.Neutral.value and self._position == Positions.Neutral:
|
||||
# 空仓观望给予重罚,强烈鼓励模型寻找机会
|
||||
return -0.5
|
||||
# 使用官方的轻惩罚
|
||||
return -1
|
||||
|
||||
# =========================================================
|
||||
# 场景 3: 持仓中 (Holding)
|
||||
# =========================================================
|
||||
if self._position in (Positions.Short, Positions.Long):
|
||||
if action == Actions.Neutral.value:
|
||||
# 持仓时,综合考虑盈亏和时间
|
||||
reward = 0.0
|
||||
# 1. 官方的持续持仓惩罚
|
||||
if trade_duration <= max_trade_duration:
|
||||
factor = 1.5
|
||||
else:
|
||||
factor = 0.5
|
||||
|
||||
# 1. 盈亏奖励/惩罚
|
||||
# 持续持仓惩罚
|
||||
time_penalty = -1 * trade_duration / max_trade_duration
|
||||
reward = time_penalty
|
||||
|
||||
# 2. 盈亏奖励/惩罚
|
||||
if pnl > 0:
|
||||
# 浮盈:直接奖励
|
||||
reward += pnl * 10.0
|
||||
@@ -72,18 +82,6 @@ class MyCoolRLModel(ReinforcementLearner):
|
||||
# 浮亏:惩罚,但不要太严厉
|
||||
reward += pnl * 2.0
|
||||
|
||||
# 2. 时间惩罚 - 特别针对亏损订单持仓太久
|
||||
# 如果浮亏且持仓时间较长,增加时间惩罚
|
||||
if pnl < 0 and trade_duration > 10: # 如果亏损且持仓超过10根K线
|
||||
# 时间惩罚随持仓时间递增
|
||||
time_penalty = -0.01 * (trade_duration - 10) # 每多持有一根K线,多惩罚0.01
|
||||
reward += time_penalty
|
||||
|
||||
# 3. 如果浮亏且持仓时间很长,惩罚加重
|
||||
if pnl < 0 and trade_duration > 30: # 如果亏损且持仓超过30根K线
|
||||
heavy_time_penalty = -0.1 * (trade_duration - 30) # 每多持有一根K线,多惩罚0.1
|
||||
reward += heavy_time_penalty
|
||||
|
||||
return reward
|
||||
|
||||
# =========================================================
|
||||
@@ -92,17 +90,16 @@ class MyCoolRLModel(ReinforcementLearner):
|
||||
if (action == Actions.Long_exit.value and self._position == Positions.Long) or \
|
||||
(action == Actions.Short_exit.value and self._position == Positions.Short):
|
||||
|
||||
# 1. 使用官方的盈利/亏损奖励机制
|
||||
if pnl > 0:
|
||||
# 盈利离场:大幅奖励
|
||||
return pnl * 50.0
|
||||
# 盈利离场:使用官方的奖励机制
|
||||
factor = 100.0
|
||||
if pnl > profit_aim * rr:
|
||||
factor *= self.rl_config["model_reward_parameters"].get("win_reward_factor", 2)
|
||||
return float(pnl * factor)
|
||||
else:
|
||||
# 亏损离场:惩罚,但要考虑持仓时间
|
||||
# 如果是长期亏损持仓被迫离场,惩罚稍重
|
||||
# 如果是短期止损离场,惩罚较轻
|
||||
base_penalty = pnl * 5.0
|
||||
if pnl < 0 and trade_duration > 20: # 长期亏损持仓
|
||||
return base_penalty * 1.5 # 惩罚加重50%
|
||||
else:
|
||||
return base_penalty
|
||||
# 亏损离场:使用官方的惩罚机制
|
||||
factor = 100.0
|
||||
return float(pnl * factor)
|
||||
|
||||
return 0.0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user