引入freqai优化

This commit is contained in:
zhangkun9038@dingtalk.com 2025-09-05 00:31:19 +08:00
parent 104bb08f99
commit 00f2e91599

View File

@ -34,7 +34,28 @@ class FreqaiPrimer(IStrategy):
super().__init__(config) # 调用父类的初始化方法并传递config
# 存储从配置文件加载的默认值
self._trailing_stop_positive_default = 0.004 # 降低默认值以更容易触发跟踪止盈
# FreqAI配置
freqai_info = {
"model": "LightGBMRegressor",
"feature_parameters": {
"include_timeframes": ["3m", "15m", "1h"],
"include_corr_pairlist": [],
"label_period_candles": 12,
"include_shifted_candles": 3,
},
"data_split_parameters": {
"test_size": 0.2,
"shuffle": False,
},
"model_training_parameters": {
"n_estimators": 200,
"learning_rate": 0.05,
"num_leaves": 31,
"verbose": -1,
},
}
@property
def protections(self):
"""
@ -107,6 +128,35 @@ class FreqaiPrimer(IStrategy):
if missing_columns:
logger.warning(f"[{metadata['pair']}] 数据框中缺少以下列: {missing_columns}")
def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
logger.info(f"设置 FreqAI 目标,交易对:{metadata['pair']}")
if "close" not in dataframe.columns:
logger.error("数据框缺少必要的 'close'")
raise ValueError("数据框缺少必要的 'close'")
try:
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
# 创建我们想要预测的目标动态RSI入场阈值
dataframe["&-rsi_entry_threshold"] = np.where(
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 50, 45
)
# 计算未来收益作为额外目标,帮助模型学习
dataframe["&-future_return"] = dataframe["close"].pct_change(label_period).shift(-label_period)
# 处理可能的NaN值
for col in ["&-rsi_entry_threshold", "&-future_return"]:
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0)
dataframe[col] = dataframe[col].ffill()
dataframe[col] = dataframe[col].fillna(0)
except Exception as e:
logger.error(f"创建 FreqAI 目标失败:{str(e)}")
raise
logger.info(f"目标列预览:\n{dataframe[["&-rsi_entry_threshold"]].head().to_string()}")
return dataframe
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
# 计算 3m 周期的指标
bb_3m = ta.bbands(dataframe['close'], length=self.bb_length, std=self.bb_std)
@ -198,10 +248,11 @@ class FreqaiPrimer(IStrategy):
# 确保重新索引时不引入未来数据
df_1h = df_1h.set_index('date').reindex(dataframe['date']).ffill().fillna(method='bfill').reset_index()
df_1h = df_1h.rename(columns={'index': 'date'})
# Include macd_1h and macd_signal_1h in the column selection
# Include macd_1h and macd_signal_1h in the column selection
df_1h = df_1h[['date', 'rsi_1h', 'trend_1h', 'ema_50_1h', 'ema_200_1h', 'bb_lower_1h', 'bb_upper_1h', 'stochrsi_k_1h', 'stochrsi_d_1h', 'macd_1h', 'macd_signal_1h']].ffill()
# Validate that all required columns are present
# Validate that all required columns are present
required_columns = ['date', 'rsi_1h', 'trend_1h', 'ema_50_1h', 'ema_200_1h',
'bb_lower_1h', 'bb_upper_1h', 'stochrsi_k_1h', 'stochrsi_d_1h',
'macd_1h', 'macd_signal_1h']
@ -362,9 +413,27 @@ class FreqaiPrimer(IStrategy):
else:
logger.info(f"[{metadata['pair']}] 数据修复完成,所有行数据均有效")
# 调用FreqAI进行预测
try:
dataframe = self.freqai.start(dataframe, metadata, self)
logger.info(f"[{metadata['pair']}] FreqAI处理完成输出列: {list(dataframe.columns)}")
except Exception as e:
logger.error(f"[{metadata['pair']}] FreqAI处理失败: {str(e)}")
# 如果FreqAI失败继续使用默认参数
# 动态计算每行的市场状态,避免未来数据泄漏
dataframe['current_score'] = dataframe['market_score']
dataframe['current_state'] = dataframe['market_state']
# 如果FreqAI提供了预测的RSI入场阈值使用它否则使用默认值
if '&-rsi_entry_threshold' in dataframe.columns:
dataframe['rsi_entry_threshold'] = dataframe['&-rsi_entry_threshold']
# 确保阈值在合理范围内
dataframe['rsi_entry_threshold'] = dataframe['rsi_entry_threshold'].clip(30, 60)
else:
# 回退到基于市场状态的静态阈值
dataframe['rsi_entry_threshold'] = np.where(
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 50, 45
)
# 记录最后一条记录的市场状态
if len(dataframe) > 0:
@ -418,75 +487,56 @@ class FreqaiPrimer(IStrategy):
return dataframe
def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
pair = metadata['pair']
current_time = datetime.datetime.now()
# 入场条件
# 基本条件价格接近布林带下轨RSI超卖StochRSI超卖MACD上升趋势
# 使用FreqAI优化的RSI入场阈值
if 'rsi_entry_threshold' in dataframe.columns:
# 使用FreqAI优化的动态阈值
logger.info(f"[{metadata['pair']}] 使用FreqAI优化的RSI入场阈值")
rsi_condition = dataframe['rsi_3m'] < dataframe['rsi_entry_threshold']
else:
# 回退到基于市场状态的静态阈值
logger.info(f"[{metadata['pair']}] 使用基于市场状态的RSI入场阈值")
rsi_condition = np.where(
dataframe['market_state'].isin(['strong_bull', 'weak_bull']),
dataframe['rsi_3m'] < 50, # 牛市中超卖阈值提高
dataframe['rsi_3m'] < 45 # 非牛市中超卖阈值降低
)
# 检测剧烈拉升
is_unstable_region, unstable_time = self.detect_h1_rapid_rise(pair, dataframe, metadata)
# 其他入场条件保持不变
bb_condition = dataframe['close'] < dataframe['bb_lower_3m'] # 价格在布林带下轨下方
stochrsi_condition = dataframe['stochrsi_k_3m'] < 20 # StochRSI超卖
macd_condition = dataframe['macd_3m'] > dataframe['macd_signal_3m'] # MACD金叉
volume_condition = dataframe['volume'] > dataframe['volume_ma'] * 1.5 # 成交量放大
bb_width_condition = (dataframe['bb_upper_3m'] - dataframe['bb_lower_3m']) > dataframe['bb_lower_3m'] * 0.02 # 布林带宽度足够
trend_condition = dataframe['close'] > dataframe['ema_50_3m'] # 短期趋势向上
bullish_candle_condition = dataframe['close'] > dataframe['open'] # 收阳
# 如果检测到剧烈拉升且未过回看期,则阻止入场信号
if is_unstable_region:
dataframe['enter_long'] = 0
logger.info(f"[{pair}] 处于不稳固区域,阻止入场信号")
return dataframe
# 市场状态已在 populate_indicators 中计算,无需重复获取
# 组合条件至少满足3个基本条件
# 计算满足的条件数量
dataframe['satisfied_conditions'] = 0
for condition in [rsi_condition, bb_condition, stochrsi_condition, macd_condition,
volume_condition, trend_condition]:
dataframe['satisfied_conditions'] += condition.astype(int)
# 条件1: 价格接近布林带下轨(允许一定偏差)
close_to_bb_lower_1h = (dataframe['close'] <= dataframe['bb_lower_1h'] * 1.03) # 放宽到3%偏差
# 条件2: RSI 不高于阈值(根据市场状态动态调整)
rsi_threshold = np.where(
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 50, 45
# 入场信号
dataframe['enter_long'] = (
(dataframe['satisfied_conditions'] >= 3) & # 至少满足3个条件
~self.detect_h1_rapid_rise(dataframe, metadata) # 非剧烈拉升后的回调
)
rsi_condition_1h = dataframe['rsi_1h'] < rsi_threshold
# 条件3: StochRSI 处于超卖区域(根据市场状态动态调整)
stochrsi_threshold = np.where(
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 35, 25
)
stochrsi_condition_1h = (dataframe['stochrsi_k_1h'] < stochrsi_threshold) & (dataframe['stochrsi_d_1h'] < stochrsi_threshold)
# 条件4: MACD 上升趋势
macd_condition_1h = dataframe['macd_1h'] > dataframe['macd_signal_1h']
# 条件5: 成交量显著放大(可选条件)
volume_spike = dataframe['volume'] > dataframe['volume_ma'] * 1.5
# 条件6: 布林带宽度过滤(避免窄幅震荡)
bb_width = (dataframe['bb_upper_1h'] - dataframe['bb_lower_1h']) / dataframe['close']
bb_width_condition = bb_width > 0.02 # 布林带宽度大于2%
# 辅助条件: 3m 和 15m 趋势确认(允许部分时间框架不一致)
trend_confirmation = (dataframe['trend_3m'] == 1) | (dataframe['trend_15m'] == 1)
# 合并所有条件(减少强制性条件)
# 至少满足5个条件中的3个
condition_count = (
close_to_bb_lower_1h.astype(int) +
rsi_condition_1h.astype(int) +
stochrsi_condition_1h.astype(int) +
macd_condition_1h.astype(int) +
(volume_spike | bb_width_condition).astype(int) + # 成交量或布林带宽度满足其一即可
trend_confirmation.astype(int)
)
final_condition = condition_count >= 3
# 设置入场信号
dataframe.loc[final_condition, 'enter_long'] = 1
# 增强调试信息
logger.info(f"[{metadata['pair']}] 入场条件检查:")
logger.info(f" - 价格接近布林带下轨: {close_to_bb_lower_1h.sum()}")
logger.info(f" - RSI 超卖: {rsi_condition_1h.sum()}")
logger.info(f" - StochRSI 超卖: {stochrsi_condition_1h.sum()}")
logger.info(f" - MACD 上升趋势: {macd_condition_1h.sum()}")
logger.info(f" - 成交量或布林带宽度: {(volume_spike | bb_width_condition).sum()}")
logger.info(f" - 趋势确认: {trend_confirmation.sum()}")
logger.info(f" - 最终条件: {final_condition.sum()}")
# 日志记录
if dataframe['enter_long'].sum() > 0:
logger.info(f"[{metadata['pair']}] 发现入场信号数量: {dataframe['enter_long'].sum()}")
# 在入场点添加入场标签,用于可视化和调试
if 'enter_long' in dataframe.columns:
entry_points = dataframe[dataframe['enter_long']].index
if len(entry_points) > 0:
logger.info(f"[{metadata['pair']}] 检测到 {len(entry_points)} 个入场信号")
# 记录每个入场信号的详细信息
for i in entry_points[-3:]: # 只记录最近3个信号
row = dataframe.iloc[i]
logger.debug(f"[{metadata['pair']}] 入场信号 - 时间: {row.name}, \
价格: {row['close']:.4f}, \
RSI: {row['rsi_3m']:.2f}, \
满足条件数: {row['satisfied_conditions']}")
return dataframe