426 lines
18 KiB
Python
426 lines
18 KiB
Python
import datetime
import logging
from typing import Dict, List, Optional

import ccxt
import numpy as np
import pandas as pd
import talib as ta
from freqtrade.persistence import Trade
from freqtrade.strategy import CategoricalParameter, DecimalParameter
from freqtrade.strategy import IStrategy
from pandas import DataFrame
from sklearn.metrics import mean_squared_error
from xgboost import XGBRegressor
|
||
|
||
|
||
|
||
# Module-level logger, named after this module; used by every method of the strategy below.
logger = logging.getLogger(__name__)
|
||
|
||
class OKXRegressionStrategy(IStrategy):
    """
    Long-only Freqtrade strategy driven by FreqAI regression predictions.

    - Intended for OKX market data. Nothing in this class calls CCXT directly;
      candles arrive through the dataframes handed to the methods below.
    - Training parameters describe an XGBoost regressor predicting price change.
    - Only long (buy) signals are produced; leverage is pinned to 1.0.
    - Written against Freqtrade 2025.3 and derives from ``IStrategy``.
    """

    # Maximum number of startup candles required by the indicators.
    startup_candle_count: int = 20

    # Strategy metadata (recommended to be configured through config.json).
    trailing_stop = True
    trailing_stop_positive = 0.01
    max_open_trades = 3
    # NOTE(review): confirm 'dynamic' is a stake_amount value the installed
    # Freqtrade accepts; sizing is actually done in custom_stake_amount().
    stake_amount = 'dynamic'
    atr_period = CategoricalParameter([7, 14, 21], default=14, space='buy')
    atr_multiplier = DecimalParameter(1.0, 3.0, default=2.0, space='sell')

    # FreqAI configuration. Within this class the dict is read only by
    # set_freqai_targets() (label_period_candles) and fit()
    # (model_training_parameters).
    # NOTE(review): presumably FreqAI itself takes its settings from the bot
    # configuration - keep both in sync.
    freqai_config = {
        "enabled": True,
        "identifier": "okx_regression_v1",
        "model_training_parameters": {
            "n_estimators": 100,
            "learning_rate": 0.05,
            "max_depth": 6
        },
        "feature_parameters": {
            "include_timeframes": ["5m", "15m", "1h"],
            "include_corr_pairlist": ["BTC/USDT", "ETH/USDT"],
            "label_period_candles": 12,
            "include_shifted_candles": 2,  # add historical shifted features
            "principal_component_analysis": True  # enable PCA
        },
        "data_split_parameters": {
            "test_size": 0.2,
            "random_state": 42,
            "shuffle": False
        },
        "train_period_days": 90,
        "backtest_period_days": 30,
        "purge_old_models": True  # purge old models
    }

    # Per-pair ADX regime thresholds and the stoploss / ROI value used in each
    # regime (trend, oscillation, in-between). Pairs not listed here keep 0.0
    # for both "&-stoploss_pred" and "&-roi_0_pred".
    _PAIR_THRESHOLDS = {
        "DOGE/USDT": {
            "adx_trend": 20,                # ADX above this => trending market
            "adx_oscillation": 15,          # ADX below this => ranging market
            "stoploss_trend": -0.08,        # trending stoploss: -8%
            "stoploss_oscillation": -0.04,  # ranging stoploss: -4%
            "stoploss_mid": -0.06,          # in-between stoploss: -6%
            "roi_trend": 0.06,              # trending ROI: 6%
            "roi_oscillation": 0.025,       # ranging ROI: 2.5%
            "roi_mid": 0.04                 # in-between ROI: 4%
        },
        "BTC/USDT": {
            "adx_trend": 25,
            "adx_oscillation": 20,
            "stoploss_trend": -0.03,
            "stoploss_oscillation": -0.015,
            "stoploss_mid": -0.02,
            "roi_trend": 0.03,
            "roi_oscillation": 0.015,
            "roi_mid": 0.02
        },
        "SOL/USDT": {
            "adx_trend": 22,
            "adx_oscillation": 18,
            "stoploss_trend": -0.06,
            "stoploss_oscillation": -0.03,
            "stoploss_mid": -0.045,
            "roi_trend": 0.045,
            "roi_oscillation": 0.02,
            "roi_mid": 0.03
        },
        "XRP/USDT": {
            "adx_trend": 22,
            "adx_oscillation": 18,
            "stoploss_trend": -0.06,
            "stoploss_oscillation": -0.03,
            "stoploss_mid": -0.045,
            "roi_trend": 0.045,
            "roi_oscillation": 0.02,
            "roi_mid": 0.03
        }
    }

    def __init__(self, config: Dict):
        """
        Initialise the strategy.

        :param config: bot configuration, forwarded unchanged to ``IStrategy``
        """
        super().__init__(config)
        # Kept only so external code that touches this attribute keeps working.
        # It is no longer consulted: returning a frame cached per pair/period
        # handed back stale candles and ignored the dataframe actually passed in.
        self.feature_cache = {}

    def _atr_column(self) -> str:
        """Return the ATR column name for the currently selected ``atr_period``."""
        return f"ATR_{self.atr_period.value}"

    def _ensure_atr(self, dataframe: DataFrame) -> str:
        """
        Make sure the ATR column for the current ``atr_period`` exists.

        :param dataframe: frame with ``high``, ``low`` and ``close`` columns; modified in place
        :return: name of the ATR column
        """
        column = self._atr_column()
        if column not in dataframe.columns:
            dataframe[column] = ta.ATR(dataframe['high'], dataframe['low'], dataframe['close'],
                                       timeperiod=self.atr_period.value)
        return column

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: Dict, **kwargs) -> DataFrame:
        """
        Add the period-dependent features (RSI, MACD, Bollinger width, volume SMA).

        The features are always recomputed from the dataframe that is passed in.

        :param dataframe: candles with at least ``close`` and ``volume`` columns
        :param period: indicator period used for RSI, Bollinger bands and the volume SMA
        :param metadata: pair information (not needed for the computation)
        :return: dataframe with the feature columns added, inf replaced and NaN filled
        """
        close = dataframe["close"]

        # RSI
        dataframe[f"%-%-rsi-{period}"] = ta.RSI(close, timeperiod=period)

        # MACD uses the fixed 12/26/9 setup; only the column name carries `period`.
        macd, macdsignal, _ = ta.MACD(close, fastperiod=12, slowperiod=26, signalperiod=9)
        dataframe[f"%-%-macd-{period}"] = macd
        dataframe[f"%-%-macdsignal-{period}"] = macdsignal

        # Bollinger band width, relative to the middle band.
        upper, middle, lower = ta.BBANDS(close, timeperiod=period)
        dataframe[f"%-%-bb_width-{period}"] = (upper - lower) / middle

        # Volume moving average.
        dataframe[f"%-%-volume_ma-{period}"] = ta.SMA(dataframe["volume"], timeperiod=period)

        # An order-book imbalance feature for BTC/USDT and ETH/USDT was disabled
        # in the original code; no order-book helper is defined in this class.

        # Data cleaning: a zero middle band yields inf, warm-up rows yield NaN.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
        dataframe = dataframe.ffill().fillna(0)

        logger.debug(f"周期 {period} 特征:{list(dataframe.filter(like='%-%-').columns)}")

        return dataframe

    def feature_engineering_standard(self, dataframe: DataFrame, **kwargs) -> DataFrame:
        """
        Add the base-timeframe features: price change and hour of day.

        :param dataframe: candles with a ``close`` column and, normally, a ``date`` column
        :return: dataframe with ``%-price_change`` and ``%-hour_of_day`` added and cleaned
        """
        # Relative price change from the previous candle.
        dataframe["%-price_change"] = dataframe["close"].pct_change()

        # Hour of day scaled to [0, 1). Prefer the "date" column: converting a
        # plain integer index to a DatetimeIndex treats the integers as
        # nanoseconds since the epoch, which makes every hour 0.
        if "date" in dataframe.columns:
            hours = pd.to_datetime(dataframe["date"]).dt.hour
        else:
            if not isinstance(dataframe.index, pd.DatetimeIndex):
                dataframe = dataframe.set_index(pd.DatetimeIndex(dataframe.index))
            hours = dataframe.index.hour
        dataframe["%-hour_of_day"] = hours / 24.0

        # Data cleaning.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
        dataframe = dataframe.ffill().fillna(0)

        logger.debug(f"全局特征:{list(dataframe.filter(like='%-').columns)}")
        return dataframe

    def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs) -> DataFrame:
        """
        Create the regression targets, including per-pair stoploss / ROI levels.

        :param dataframe: candles with ``close``, ``high`` and ``low`` columns
        :param metadata: must contain ``pair`` (e.g. ``"DOGE/USDT"``)
        :return: dataframe with the ``&-`` target columns and a helper ``adx`` column
        :raises ValueError: if a required target column is missing afterwards
        """
        # Label horizon in candles, taken from the class-level configuration.
        label_period = self.freqai_config["feature_parameters"]["label_period_candles"]
        pair = metadata["pair"]
        close = dataframe["close"]

        # Relative price change `label_period` candles ahead.
        dataframe["&-s_close"] = (close.shift(-label_period) - close) / close

        # Relative price change over fixed look-ahead windows; the minute values
        # are converted to candles assuming 5-minute candles.
        for minutes in [0, 15, 30]:
            candles = int(minutes / 5)
            if candles > 0:
                dataframe[f"&-roi_{minutes}"] = (close.shift(-candles) - close) / close
            else:
                dataframe[f"&-roi_{minutes}"] = 0.0

        # Market-regime indicator: 14-period ADX.
        dataframe["adx"] = ta.ADX(dataframe["high"], dataframe["low"], dataframe["close"], timeperiod=14)

        # Regime-dependent stoploss / ROI levels. Rows whose ADX is NaN match
        # neither comparison and therefore get the in-between value.
        dataframe["&-stoploss_pred"] = 0.0
        dataframe["&-roi_0_pred"] = 0.0
        thresholds = self._PAIR_THRESHOLDS.get(pair)
        if thresholds:
            adx = dataframe["adx"]
            regimes = [adx > thresholds["adx_trend"], adx < thresholds["adx_oscillation"]]
            stoploss = np.select(
                regimes,
                [thresholds["stoploss_trend"], thresholds["stoploss_oscillation"]],
                default=thresholds["stoploss_mid"],
            )
            roi = np.select(
                regimes,
                [thresholds["roi_trend"], thresholds["roi_oscillation"]],
                default=thresholds["roi_mid"],
            )
            # Risk control: stoploss never wider than -10%, ROI never above 10%.
            dataframe["&-stoploss_pred"] = np.maximum(stoploss, -0.10)
            dataframe["&-roi_0_pred"] = np.minimum(roi, 0.10)

        # Smoothed RSI used as the dynamic buy threshold.
        dataframe["&-buy_rsi_pred"] = ta.RSI(close, timeperiod=14).rolling(20).mean()

        # Data cleaning. Target ("&") columns keep NaN where the value is
        # undefined (e.g. the last rows of the forward-shifted labels):
        # forward-filling or zero-filling them would fabricate training labels.
        # NOTE(review): this relies on FreqAI discarding rows with NaN labels
        # during training - confirm for the installed version.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
        fillable = [col for col in dataframe.columns if not str(col).startswith("&")]
        dataframe[fillable] = dataframe[fillable].ffill().fillna(0)

        # Validate that every expected target exists.
        required_targets = ["&-s_close", "&-roi_0", "&-buy_rsi_pred", "&-stoploss_pred", "&-roi_0_pred"]
        missing_targets = [col for col in required_targets if col not in dataframe.columns]
        if missing_targets:
            logger.error(f"缺少目标列:{missing_targets}")
            raise ValueError(f"目标初始化失败:{missing_targets}")

        logger.debug(f"目标初始化完成。DataFrame 形状:{dataframe.shape}")
        return dataframe

    def populate_indicators(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Run FreqAI on the dataframe and add the ATR column used by the exits.

        :param dataframe: candle dataframe for the pair
        :param metadata: pair information passed through to FreqAI
        :return: dataframe returned by ``self.freqai.start`` plus ``ATR_<atr_period>``
        :raises Exception: any error from FreqAI or the indicator code is logged and re-raised
        """
        try:
            logger.debug(f"FreqAI 对象:{type(self.freqai)}")
            dataframe = self.freqai.start(dataframe, metadata, self)

            # Data integrity: fill gaps only when close/volume actually have NaN.
            if dataframe["close"].isna().any() or dataframe["volume"].isna().any():
                logger.warning("检测到 OKX 数据缺失,使用前向填充")
                dataframe = dataframe.ffill().fillna(0)

            # Prediction statistics.
            if "&-s_close" in dataframe.columns:
                logger.debug(f"预测统计:均值={dataframe['&-s_close'].mean():.4f}, "
                             f"方差={dataframe['&-s_close'].var():.4f}")

            # ATR must be computed before returning; it previously sat after
            # the try block, where it could never run.
            dataframe[self._atr_column()] = ta.ATR(dataframe['high'], dataframe['low'], dataframe['close'],
                                                   timeperiod=self.atr_period.value)

            logger.debug(f"生成的列:{list(dataframe.columns)}")
            return dataframe
        except Exception as e:
            logger.error(f"FreqAI start 失败:{str(e)}")
            raise
        finally:
            logger.debug("populate_indicators 完成")

    def populate_entry_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Generate long-only entry signals from the regression prediction.

        :param dataframe: frame holding ``&-s_close``, ``do_predict``, ``&-buy_rsi_pred`` and ``open``
        :param metadata: pair information (unused)
        :return: dataframe with ``enter_long``, ``enter_tag`` and ``entry_price`` set
        """
        # Fall back to 0.0 when the RSI feature is absent; the RSI condition then
        # passes whenever the dynamic threshold is positive.
        if "%-%-rsi-14" not in dataframe.columns:
            dataframe["%-%-rsi-14"] = 0.0

        # Make sure the signal column exists even when no row qualifies.
        if "enter_long" not in dataframe.columns:
            dataframe["enter_long"] = 0

        long_signal = (
            (dataframe["&-s_close"] > 0.01) &  # predicted rise above 1%
            (dataframe["do_predict"] == 1) &   # prediction flagged as reliable
            (dataframe["%-%-rsi-14"] < dataframe["&-buy_rsi_pred"])  # RSI below dynamic threshold
        )
        # The 'long' tag is what custom_stoploss() checks on the resulting trade.
        dataframe.loc[long_signal, ["enter_long", "enter_tag"]] = (1, "long")

        # Open of the most recent signal candle, carried forward for the stop logic.
        dataframe['entry_price'] = dataframe['open'].where(dataframe['enter_long'] == 1).ffill()

        logger.debug(f"生成 {dataframe['enter_long'].sum()} 个做多信号")
        return dataframe

    def _dynamic_stop_loss(self, dataframe: DataFrame, metadata: dict, atr_col: str = 'ATR_14', multiplier: float = 2.0) -> DataFrame:
        """
        Compute the ATR-based stop line and flag exits below it.

        :param dataframe: frame with ``open``, ``close``, ``enter_long`` and the ATR column
        :param metadata: strategy metadata (unused)
        :param atr_col: name of the ATR column to use
        :param multiplier: ATR multiplier subtracted from the entry price
        :return: dataframe with ``entry_price``, ``stop_loss_line`` and ``exit_long`` updated
        """
        dataframe['entry_price'] = dataframe['open'].where(dataframe['enter_long'] == 1).ffill()
        dataframe['stop_loss_line'] = dataframe['entry_price'] - dataframe[atr_col] * multiplier

        # Rows without a stop line (NaN) never satisfy the comparison.
        dataframe.loc[
            (dataframe['close'] < dataframe['stop_loss_line']),
            'exit_long'
        ] = 1

        return dataframe

    def populate_exit_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Flag exits where the close falls below the ATR-based stop line.

        :param dataframe: frame produced by ``populate_entry_trend``
        :param metadata: pair information, forwarded to ``_dynamic_stop_loss``
        :return: dataframe with ``stop_loss_line`` and ``exit_long`` populated
        """
        # Compute a missing ATR instead of substituting 0.0, which would put the
        # stop line exactly at the entry price.
        atr_col = self._ensure_atr(dataframe)

        # Apply the stop logic with the configured period and multiplier.
        dataframe = self._dynamic_stop_loss(dataframe, metadata, atr_col=atr_col,
                                            multiplier=self.atr_multiplier.value)

        # Report the latest values; skipped for an empty frame, where
        # iloc[-1] would raise.
        if self.dp and not dataframe.empty:
            self.dp.send_msg(f"ATR: {dataframe[atr_col].iloc[-1]:.5f}, Stop Loss Line: {dataframe['stop_loss_line'].iloc[-1]:.5f}")

        return dataframe

    def custom_stake_amount(self, pair: str, current_time: 'datetime.datetime', current_rate: float,
                            proposed_stake: float, min_stake: Optional[float], max_stake: float,
                            entry_tag: Optional[str], **kwargs) -> float:
        """
        Dynamic position sizing: 2% of the available stake balance per trade.

        :param min_stake: lower bound for the stake; ``None`` means no lower bound
        :param max_stake: upper bound for the stake
        :return: 2% of the available balance, clamped to ``[min_stake, max_stake]``
        """
        balance = self.wallets.get_available_stake_amount()
        stake = balance * 0.02
        # max() against None raises TypeError, so only apply a real lower bound.
        if min_stake is not None:
            stake = max(stake, min_stake)
        return min(stake, max_stake)

    def custom_stoploss(self, pair: str, trade: 'Trade', current_time: 'datetime.datetime',
                        current_rate: float, profit_percent: Optional[float] = None,
                        after_fill: bool = False, **kwargs) -> Optional[float]:
        """
        ATR-based stoploss for trades tagged ``'long'``.

        ``profit_percent`` and ``after_fill`` have defaults so the method can be
        called with the keyword ``current_profit`` (absorbed by ``**kwargs``)
        as well as with the original positional arguments.

        NOTE(review): no ``use_custom_stoploss`` attribute is set in this class;
        confirm it is enabled elsewhere if this callback should be active.

        :param pair: pair of the trade
        :param trade: trade object; only ``enter_tag`` is read
        :param current_rate: current price
        :return: stop distance relative to ``current_rate`` (``-1.5 * ATR / current_rate``),
                 or ``None`` when the trade is not tagged ``'long'`` or no usable ATR exists
        """
        if trade.enter_tag != 'long':
            return None

        # ATR lives in the analyzed dataframe, not in the raw candle data.
        analyzed, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        atr_col = self._atr_column()
        if analyzed is None or analyzed.empty or atr_col not in analyzed.columns:
            return None

        atr_value = analyzed[atr_col].iloc[-1]
        if pd.isna(atr_value) or current_rate <= 0:
            return None

        trailing_stop = current_rate - atr_value * 1.5
        return trailing_stop / current_rate - 1  # relative distance to current_rate

    def leverage(self, pair: str, current_time: 'datetime.datetime', current_rate: float,
                 proposed_leverage: float, max_leverage: float, side: str,
                 **kwargs) -> float:
        """
        Disable leverage; the strategy is long-only.

        :return: always ``1.0``
        """
        return 1.0

    def confirm_trade_entry(self, pair: str, order_type: str, amount: float, rate: float,
                            time_in_force: str, current_time: 'datetime.datetime', **kwargs) -> bool:
        """
        Confirm a trade entry after checking data freshness.

        :return: ``False`` when ``check_data_freshness`` rejects the pair, else ``True``
        """
        if not self.check_data_freshness(pair, current_time):
            logger.warning(f"{pair} 的 OKX 数据过期,跳过交易")
            return False
        return True

    def check_data_freshness(self, pair: str, current_time: 'datetime.datetime') -> bool:
        """
        Simplified freshness check that does not depend on any external API.

        :return: always ``True`` (data is assumed fresh; placeholder for testing)
        """
        return True

    def fit(self, data_dictionary: Dict, metadata: Dict, **kwargs) -> None:
        """
        Train through FreqAI and log train/test MSE and feature importances.

        :param data_dictionary: must hold ``train_features`` and ``train_labels``;
                                ``test_features`` / ``test_labels`` are optional
        :param metadata: forwarded unchanged to ``self.freqai.fit``
        :raises Exception: any failure is logged and re-raised
        """
        try:
            # Create a placeholder model if none exists yet.
            if not hasattr(self, 'model') or self.model is None:
                model_params = self.freqai_config["model_training_parameters"]
                self.model = XGBRegressor(**model_params)
                logger.debug("初始化新的 XGBoost 回归模型")

            # Delegate training to FreqAI. When it hands back a model object,
            # evaluate that one: the placeholder above is never fitted here, so
            # predicting with it cannot work.
            trained_model = self.freqai.fit(data_dictionary, metadata, **kwargs)
            if trained_model is not None:
                self.model = trained_model

            # Training-set performance.
            train_data = data_dictionary["train_features"]
            train_labels = data_dictionary["train_labels"]
            train_predictions = self.model.predict(train_data)
            train_mse = mean_squared_error(train_labels, train_predictions)
            logger.info(f"训练集 MSE:{train_mse:.6f}")

            # Test-set performance (when available).
            if "test_features" in data_dictionary:
                test_data = data_dictionary["test_features"]
                test_labels = data_dictionary["test_labels"]
                test_predictions = self.model.predict(test_data)
                test_mse = mean_squared_error(test_labels, test_predictions)
                logger.info(f"测试集 MSE:{test_mse:.6f}")

            # Feature importances.
            if hasattr(self.model, 'feature_importances_'):
                importance = self.model.feature_importances_
                logger.debug(f"特征重要性:{dict(zip(train_data.columns, importance))}")
        except Exception as e:
            logger.error(f"FreqAI fit 失败:{str(e)}")
            raise
|