引入freqai优化
This commit is contained in:
parent
104bb08f99
commit
00f2e91599
@ -34,7 +34,28 @@ class FreqaiPrimer(IStrategy):
|
||||
super().__init__(config) # 调用父类的初始化方法并传递config
|
||||
# 存储从配置文件加载的默认值
|
||||
self._trailing_stop_positive_default = 0.004 # 降低默认值以更容易触发跟踪止盈
|
||||
|
||||
|
||||
# FreqAI配置
|
||||
freqai_info = {
|
||||
"model": "LightGBMRegressor",
|
||||
"feature_parameters": {
|
||||
"include_timeframes": ["3m", "15m", "1h"],
|
||||
"include_corr_pairlist": [],
|
||||
"label_period_candles": 12,
|
||||
"include_shifted_candles": 3,
|
||||
},
|
||||
"data_split_parameters": {
|
||||
"test_size": 0.2,
|
||||
"shuffle": False,
|
||||
},
|
||||
"model_training_parameters": {
|
||||
"n_estimators": 200,
|
||||
"learning_rate": 0.05,
|
||||
"num_leaves": 31,
|
||||
"verbose": -1,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def protections(self):
|
||||
"""
|
||||
@ -107,6 +128,35 @@ class FreqaiPrimer(IStrategy):
|
||||
if missing_columns:
|
||||
logger.warning(f"[{metadata['pair']}] 数据框中缺少以下列: {missing_columns}")
|
||||
|
||||
def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
|
||||
logger.info(f"设置 FreqAI 目标,交易对:{metadata['pair']}")
|
||||
if "close" not in dataframe.columns:
|
||||
logger.error("数据框缺少必要的 'close' 列")
|
||||
raise ValueError("数据框缺少必要的 'close' 列")
|
||||
|
||||
try:
|
||||
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
|
||||
|
||||
# 创建我们想要预测的目标:动态RSI入场阈值
|
||||
dataframe["&-rsi_entry_threshold"] = np.where(
|
||||
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 50, 45
|
||||
)
|
||||
|
||||
# 计算未来收益作为额外目标,帮助模型学习
|
||||
dataframe["&-future_return"] = dataframe["close"].pct_change(label_period).shift(-label_period)
|
||||
|
||||
# 处理可能的NaN值
|
||||
for col in ["&-rsi_entry_threshold", "&-future_return"]:
|
||||
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0)
|
||||
dataframe[col] = dataframe[col].ffill()
|
||||
dataframe[col] = dataframe[col].fillna(0)
|
||||
except Exception as e:
|
||||
logger.error(f"创建 FreqAI 目标失败:{str(e)}")
|
||||
raise
|
||||
|
||||
logger.info(f"目标列预览:\n{dataframe[["&-rsi_entry_threshold"]].head().to_string()}")
|
||||
return dataframe
|
||||
|
||||
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
|
||||
# 计算 3m 周期的指标
|
||||
bb_3m = ta.bbands(dataframe['close'], length=self.bb_length, std=self.bb_std)
|
||||
@ -198,10 +248,11 @@ class FreqaiPrimer(IStrategy):
|
||||
# 确保重新索引时不引入未来数据
|
||||
df_1h = df_1h.set_index('date').reindex(dataframe['date']).ffill().fillna(method='bfill').reset_index()
|
||||
df_1h = df_1h.rename(columns={'index': 'date'})
|
||||
# Include macd_1h and macd_signal_1h in the column selection
|
||||
|
||||
# Include macd_1h and macd_signal_1h in the column selection
|
||||
df_1h = df_1h[['date', 'rsi_1h', 'trend_1h', 'ema_50_1h', 'ema_200_1h', 'bb_lower_1h', 'bb_upper_1h', 'stochrsi_k_1h', 'stochrsi_d_1h', 'macd_1h', 'macd_signal_1h']].ffill()
|
||||
|
||||
# Validate that all required columns are present
|
||||
# Validate that all required columns are present
|
||||
required_columns = ['date', 'rsi_1h', 'trend_1h', 'ema_50_1h', 'ema_200_1h',
|
||||
'bb_lower_1h', 'bb_upper_1h', 'stochrsi_k_1h', 'stochrsi_d_1h',
|
||||
'macd_1h', 'macd_signal_1h']
|
||||
@ -362,9 +413,27 @@ class FreqaiPrimer(IStrategy):
|
||||
else:
|
||||
logger.info(f"[{metadata['pair']}] 数据修复完成,所有行数据均有效")
|
||||
|
||||
# 调用FreqAI进行预测
|
||||
try:
|
||||
dataframe = self.freqai.start(dataframe, metadata, self)
|
||||
logger.info(f"[{metadata['pair']}] FreqAI处理完成,输出列: {list(dataframe.columns)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[{metadata['pair']}] FreqAI处理失败: {str(e)}")
|
||||
# 如果FreqAI失败,继续使用默认参数
|
||||
|
||||
# 动态计算每行的市场状态,避免未来数据泄漏
|
||||
dataframe['current_score'] = dataframe['market_score']
|
||||
dataframe['current_state'] = dataframe['market_state']
|
||||
|
||||
# 如果FreqAI提供了预测的RSI入场阈值,使用它;否则使用默认值
|
||||
if '&-rsi_entry_threshold' in dataframe.columns:
|
||||
dataframe['rsi_entry_threshold'] = dataframe['&-rsi_entry_threshold']
|
||||
# 确保阈值在合理范围内
|
||||
dataframe['rsi_entry_threshold'] = dataframe['rsi_entry_threshold'].clip(30, 60)
|
||||
else:
|
||||
# 回退到基于市场状态的静态阈值
|
||||
dataframe['rsi_entry_threshold'] = np.where(
|
||||
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 50, 45
|
||||
)
|
||||
|
||||
# 记录最后一条记录的市场状态
|
||||
if len(dataframe) > 0:
|
||||
@ -418,75 +487,56 @@ class FreqaiPrimer(IStrategy):
|
||||
return dataframe
|
||||
|
||||
def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
|
||||
pair = metadata['pair']
|
||||
current_time = datetime.datetime.now()
|
||||
# 入场条件
|
||||
# 基本条件:价格接近布林带下轨,RSI超卖,StochRSI超卖,MACD上升趋势
|
||||
# 使用FreqAI优化的RSI入场阈值
|
||||
if 'rsi_entry_threshold' in dataframe.columns:
|
||||
# 使用FreqAI优化的动态阈值
|
||||
logger.info(f"[{metadata['pair']}] 使用FreqAI优化的RSI入场阈值")
|
||||
rsi_condition = dataframe['rsi_3m'] < dataframe['rsi_entry_threshold']
|
||||
else:
|
||||
# 回退到基于市场状态的静态阈值
|
||||
logger.info(f"[{metadata['pair']}] 使用基于市场状态的RSI入场阈值")
|
||||
rsi_condition = np.where(
|
||||
dataframe['market_state'].isin(['strong_bull', 'weak_bull']),
|
||||
dataframe['rsi_3m'] < 50, # 牛市中超卖阈值提高
|
||||
dataframe['rsi_3m'] < 45 # 非牛市中超卖阈值降低
|
||||
)
|
||||
|
||||
# 检测剧烈拉升
|
||||
is_unstable_region, unstable_time = self.detect_h1_rapid_rise(pair, dataframe, metadata)
|
||||
# 其他入场条件保持不变
|
||||
bb_condition = dataframe['close'] < dataframe['bb_lower_3m'] # 价格在布林带下轨下方
|
||||
stochrsi_condition = dataframe['stochrsi_k_3m'] < 20 # StochRSI超卖
|
||||
macd_condition = dataframe['macd_3m'] > dataframe['macd_signal_3m'] # MACD金叉
|
||||
volume_condition = dataframe['volume'] > dataframe['volume_ma'] * 1.5 # 成交量放大
|
||||
bb_width_condition = (dataframe['bb_upper_3m'] - dataframe['bb_lower_3m']) > dataframe['bb_lower_3m'] * 0.02 # 布林带宽度足够
|
||||
trend_condition = dataframe['close'] > dataframe['ema_50_3m'] # 短期趋势向上
|
||||
bullish_candle_condition = dataframe['close'] > dataframe['open'] # 收阳
|
||||
|
||||
# 如果检测到剧烈拉升且未过回看期,则阻止入场信号
|
||||
if is_unstable_region:
|
||||
dataframe['enter_long'] = 0
|
||||
logger.info(f"[{pair}] 处于不稳固区域,阻止入场信号")
|
||||
return dataframe
|
||||
# 市场状态已在 populate_indicators 中计算,无需重复获取
|
||||
# 组合条件:至少满足3个基本条件
|
||||
# 计算满足的条件数量
|
||||
dataframe['satisfied_conditions'] = 0
|
||||
for condition in [rsi_condition, bb_condition, stochrsi_condition, macd_condition,
|
||||
volume_condition, trend_condition]:
|
||||
dataframe['satisfied_conditions'] += condition.astype(int)
|
||||
|
||||
# 条件1: 价格接近布林带下轨(允许一定偏差)
|
||||
close_to_bb_lower_1h = (dataframe['close'] <= dataframe['bb_lower_1h'] * 1.03) # 放宽到3%偏差
|
||||
|
||||
# 条件2: RSI 不高于阈值(根据市场状态动态调整)
|
||||
rsi_threshold = np.where(
|
||||
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 50, 45
|
||||
# 入场信号
|
||||
dataframe['enter_long'] = (
|
||||
(dataframe['satisfied_conditions'] >= 3) & # 至少满足3个条件
|
||||
~self.detect_h1_rapid_rise(dataframe, metadata) # 非剧烈拉升后的回调
|
||||
)
|
||||
rsi_condition_1h = dataframe['rsi_1h'] < rsi_threshold
|
||||
|
||||
# 条件3: StochRSI 处于超卖区域(根据市场状态动态调整)
|
||||
stochrsi_threshold = np.where(
|
||||
dataframe['market_state'].isin(['strong_bull', 'weak_bull']), 35, 25
|
||||
)
|
||||
stochrsi_condition_1h = (dataframe['stochrsi_k_1h'] < stochrsi_threshold) & (dataframe['stochrsi_d_1h'] < stochrsi_threshold)
|
||||
|
||||
# 条件4: MACD 上升趋势
|
||||
macd_condition_1h = dataframe['macd_1h'] > dataframe['macd_signal_1h']
|
||||
|
||||
# 条件5: 成交量显著放大(可选条件)
|
||||
volume_spike = dataframe['volume'] > dataframe['volume_ma'] * 1.5
|
||||
|
||||
# 条件6: 布林带宽度过滤(避免窄幅震荡)
|
||||
bb_width = (dataframe['bb_upper_1h'] - dataframe['bb_lower_1h']) / dataframe['close']
|
||||
bb_width_condition = bb_width > 0.02 # 布林带宽度大于2%
|
||||
|
||||
# 辅助条件: 3m 和 15m 趋势确认(允许部分时间框架不一致)
|
||||
trend_confirmation = (dataframe['trend_3m'] == 1) | (dataframe['trend_15m'] == 1)
|
||||
|
||||
# 合并所有条件(减少强制性条件)
|
||||
# 至少满足5个条件中的3个
|
||||
condition_count = (
|
||||
close_to_bb_lower_1h.astype(int) +
|
||||
rsi_condition_1h.astype(int) +
|
||||
stochrsi_condition_1h.astype(int) +
|
||||
macd_condition_1h.astype(int) +
|
||||
(volume_spike | bb_width_condition).astype(int) + # 成交量或布林带宽度满足其一即可
|
||||
trend_confirmation.astype(int)
|
||||
)
|
||||
final_condition = condition_count >= 3
|
||||
|
||||
# 设置入场信号
|
||||
dataframe.loc[final_condition, 'enter_long'] = 1
|
||||
|
||||
# 增强调试信息
|
||||
logger.info(f"[{metadata['pair']}] 入场条件检查:")
|
||||
logger.info(f" - 价格接近布林带下轨: {close_to_bb_lower_1h.sum()} 次")
|
||||
logger.info(f" - RSI 超卖: {rsi_condition_1h.sum()} 次")
|
||||
logger.info(f" - StochRSI 超卖: {stochrsi_condition_1h.sum()} 次")
|
||||
logger.info(f" - MACD 上升趋势: {macd_condition_1h.sum()} 次")
|
||||
logger.info(f" - 成交量或布林带宽度: {(volume_spike | bb_width_condition).sum()} 次")
|
||||
logger.info(f" - 趋势确认: {trend_confirmation.sum()} 次")
|
||||
logger.info(f" - 最终条件: {final_condition.sum()} 次")
|
||||
|
||||
# 日志记录
|
||||
if dataframe['enter_long'].sum() > 0:
|
||||
logger.info(f"[{metadata['pair']}] 发现入场信号数量: {dataframe['enter_long'].sum()}")
|
||||
# 在入场点添加入场标签,用于可视化和调试
|
||||
if 'enter_long' in dataframe.columns:
|
||||
entry_points = dataframe[dataframe['enter_long']].index
|
||||
if len(entry_points) > 0:
|
||||
logger.info(f"[{metadata['pair']}] 检测到 {len(entry_points)} 个入场信号")
|
||||
# 记录每个入场信号的详细信息
|
||||
for i in entry_points[-3:]: # 只记录最近3个信号
|
||||
row = dataframe.iloc[i]
|
||||
logger.debug(f"[{metadata['pair']}] 入场信号 - 时间: {row.name}, \
|
||||
价格: {row['close']:.4f}, \
|
||||
RSI: {row['rsi_3m']:.2f}, \
|
||||
满足条件数: {row['satisfied_conditions']}")
|
||||
|
||||
return dataframe
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user