myTestFreqAI/result/backtest-result-2025-05-04_10-29-18_OKXRegressionStrategy.py
zhangkun9038@dingtalk.com 1dc550f347 胜率35%
2025-05-04 18:31:12 +08:00

426 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from freqtrade.strategy import IStrategy
from pandas import DataFrame
import pandas as pd
import numpy as np
import talib as ta
import datetime
from typing import Dict, List, Optional
from sklearn.metrics import mean_squared_error
from freqtrade.strategy import CategoricalParameter, DecimalParameter
from xgboost import XGBRegressor
import ccxt
# Module-level logger, registered under this module's import name.
logger = logging.getLogger(name=__name__)
class OKXRegressionStrategy(IStrategy):
    """
    FreqAI strategy that trades long-only on OKX data with a regression model.

    - Market data comes from the OKX exchange through CCXT.
    - An XGBoost regression model predicts the relative price change.
    - Only long (buy) signals are produced; the strategy never shorts.
    - Targets Freqtrade 2025.3 and subclasses ``IStrategy``.
    """

    # Number of warm-up candles the indicators need before signals are valid.
    # NOTE(review): TA-Lib MACD(12, 26, 9) and ADX(14) only produce values after
    # roughly 33 and 27 candles, so 20 looks too small — left unchanged so
    # existing backtest results stay comparable.
    startup_candle_count: int = 20

    # Strategy metadata (preferably supplied through config.json).
    trailing_stop = True
    trailing_stop_positive = 0.01
    max_open_trades = 3
    # NOTE(review): Freqtrade's special stake value is "unlimited"; confirm that
    # "dynamic" is accepted. Per-trade sizing is done in custom_stake_amount().
    stake_amount = 'dynamic'

    # Hyperoptable ATR look-back period and ATR stop-distance multiplier.
    atr_period = CategoricalParameter([7, 14, 21], default=14, space='buy')
    atr_multiplier = DecimalParameter(1.0, 3.0, default=2.0, space='sell')

    # FreqAI configuration.
    freqai_config = {
        "enabled": True,
        "identifier": "okx_regression_v1",
        "model_training_parameters": {
            "n_estimators": 100,
            "learning_rate": 0.05,
            "max_depth": 6
        },
        "feature_parameters": {
            "include_timeframes": ["5m", "15m", "1h"],
            "include_corr_pairlist": ["BTC/USDT", "ETH/USDT"],
            "label_period_candles": 12,
            "include_shifted_candles": 2,  # add historically shifted copies of features
            "principal_component_analysis": True  # enable PCA
        },
        "data_split_parameters": {
            "test_size": 0.2,
            "random_state": 42,
            "shuffle": False
        },
        "train_period_days": 90,
        "backtest_period_days": 30,
        "purge_old_models": True  # delete outdated models
    }

    def __init__(self, config: Dict):
        super().__init__(config)
        # Feature cache for feature_engineering_expand_all():
        # "{pair}_{timeframe}_{period}" -> (candle-window fingerprint, features).
        self.feature_cache = {}

    def _atr_column(self) -> str:
        """Name of the ATR column written by populate_indicators()."""
        return f"ATR_{self.atr_period.value}"

    @staticmethod
    def _frame_fingerprint(dataframe: DataFrame) -> Optional[tuple]:
        """
        Identify the candle window held by ``dataframe``.

        :return: ``(row count, first date, last date)``, or ``None`` when the
            frame is empty or has no ``date`` column (such frames are not cached).
        """
        if dataframe.empty or "date" not in dataframe.columns:
            return None
        dates = dataframe["date"]
        return len(dataframe), dates.iloc[0], dates.iloc[-1]

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int,
                                       metadata: Dict, **kwargs) -> DataFrame:
        """
        Build period-dependent features for every timeframe and correlated pair.

        :param dataframe: candles of one pair/timeframe (``close``/``volume`` read).
        :param period: indicator look-back period supplied by FreqAI.
        :param metadata: pair info; ``pair`` and, when present, ``tf`` are read.
        :return: the cleaned dataframe with ``%-%-*`` feature columns added.

        Results are cached per pair, timeframe and period, and a cached frame is
        only reused when it was built from the same candle window. Keying on
        pair + period alone would hand one timeframe's features to every
        timeframe and keep serving stale features once new candles arrive.
        """
        pair = metadata.get('pair', 'unknown')
        cache_key = f"{pair}_{metadata.get('tf', '')}_{period}"
        fingerprint = self._frame_fingerprint(dataframe)
        cached = self.feature_cache.get(cache_key)
        if fingerprint is not None and cached is not None and cached[0] == fingerprint:
            logger.debug(f"使用缓存特征:{cache_key}")
            # Return a copy so callers cannot mutate the cached frame.
            return cached[1].copy()

        close = dataframe["close"]
        # RSI
        dataframe[f"%-%-rsi-{period}"] = ta.RSI(close, timeperiod=period)
        # MACD always uses the 12/26/9 settings, regardless of ``period``.
        macd, macdsignal, _ = ta.MACD(close, fastperiod=12, slowperiod=26, signalperiod=9)
        dataframe[f"%-%-macd-{period}"] = macd
        dataframe[f"%-%-macdsignal-{period}"] = macdsignal
        # Relative Bollinger Band width.
        upper, middle, lower = ta.BBANDS(close, timeperiod=period)
        dataframe[f"%-%-bb_width-{period}"] = (upper - lower) / middle
        # Volume moving average.
        dataframe[f"%-%-volume_ma-{period}"] = ta.SMA(dataframe["volume"], timeperiod=period)
        # Order-book imbalance features (BTC/USDT and ETH/USDT only) are disabled.

        # Data cleaning: treat +/-inf as missing, forward-fill, then zero-fill.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
        dataframe = dataframe.ffill().fillna(0)

        if fingerprint is not None:
            self.feature_cache[cache_key] = (fingerprint, dataframe.copy())
        logger.debug(f"周期 {period} 特征:{list(dataframe.filter(like='%-%-').columns)}")
        return dataframe

    def feature_engineering_standard(self, dataframe: DataFrame, **kwargs) -> DataFrame:
        """
        Add global (non-expanded) features on the base timeframe.

        ``%-hour_of_day`` is taken from the ``date`` column. Freqtrade keeps
        candle times in that column behind a plain integer index, so converting
        the index itself yields epoch timestamps whose hour is always 0. The
        frame's index is left untouched.
        """
        # Price rate of change.
        dataframe["%-price_change"] = dataframe["close"].pct_change()
        # Time feature: hour of day scaled to [0, 1).
        if "date" in dataframe.columns:
            timestamps = pd.DatetimeIndex(pd.to_datetime(dataframe["date"]))
        else:
            # Fallback when no ``date`` column exists: interpret the index.
            timestamps = pd.DatetimeIndex(dataframe.index)
        dataframe["%-hour_of_day"] = timestamps.hour.to_numpy() / 24.0
        # Data cleaning.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
        dataframe = dataframe.ffill().fillna(0)
        logger.debug(f"全局特征:{list(dataframe.filter(like='%-').columns)}")
        return dataframe

    @staticmethod
    def _regime_values(adx: pd.Series, thresholds: Dict, prefix: str) -> np.ndarray:
        """
        Pick a per-candle value from ``thresholds`` according to the ADX regime.

        ADX above ``adx_trend`` -> ``{prefix}_trend``; ADX below
        ``adx_oscillation`` -> ``{prefix}_oscillation``; anything else,
        including NaN ADX during warm-up -> ``{prefix}_mid``.
        """
        return np.select(
            [(adx > thresholds["adx_trend"]).to_numpy(),
             (adx < thresholds["adx_oscillation"]).to_numpy()],
            [thresholds[f"{prefix}_trend"], thresholds[f"{prefix}_oscillation"]],
            default=thresholds[f"{prefix}_mid"],
        )

    def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs) -> DataFrame:
        """
        Define the regression targets, with pair-specific stop-loss/ROI levels.

        :param dataframe: candle data (``close``, ``high``, ``low`` are read).
        :param metadata: pair information; ``metadata["pair"]`` selects thresholds.
        :return: the dataframe with ``&-*`` target columns (and ``adx``) added.
        :raises ValueError: if a required target column is missing.
        """
        # Prediction horizon in candles (e.g. N candles of 5m).
        # NOTE(review): Freqtrade exposes the runtime FreqAI config as
        # self.freqai_info; this reads the class-level freqai_config — confirm
        # both agree.
        label_period = self.freqai_config["feature_parameters"]["label_period_candles"]
        pair = metadata["pair"]  # e.g. DOGE/USDT
        close = dataframe["close"]

        # Relative price change over the next ``label_period`` candles.
        dataframe["&-s_close"] = (close.shift(-label_period) - close) / close

        # Forward ROI over several horizons, assuming 5m candles; the 0-minute
        # horizon is the constant 0.0.
        for minutes in [0, 15, 30]:
            candles = int(minutes / 5)
            if candles > 0:
                dataframe[f"&-roi_{minutes}"] = (close.shift(-candles) - close) / close
            else:
                dataframe[f"&-roi_{minutes}"] = 0.0

        # Market regime indicator: 14-period ADX.
        dataframe["adx"] = ta.ADX(dataframe["high"], dataframe["low"], close, timeperiod=14)

        # Pair-specific ADX thresholds and stop-loss / ROI levels per regime
        # (trend = ADX above adx_trend, oscillation = below adx_oscillation,
        # mid = in between).
        pair_thresholds = {
            "DOGE/USDT": {
                "adx_trend": 20,
                "adx_oscillation": 15,
                "stoploss_trend": -0.08,
                "stoploss_oscillation": -0.04,
                "stoploss_mid": -0.06,
                "roi_trend": 0.06,
                "roi_oscillation": 0.025,
                "roi_mid": 0.04
            },
            "BTC/USDT": {
                "adx_trend": 25,
                "adx_oscillation": 20,
                "stoploss_trend": -0.03,
                "stoploss_oscillation": -0.015,
                "stoploss_mid": -0.02,
                "roi_trend": 0.03,
                "roi_oscillation": 0.015,
                "roi_mid": 0.02
            },
            "SOL/USDT": {
                "adx_trend": 22,
                "adx_oscillation": 18,
                "stoploss_trend": -0.06,
                "stoploss_oscillation": -0.03,
                "stoploss_mid": -0.045,
                "roi_trend": 0.045,
                "roi_oscillation": 0.02,
                "roi_mid": 0.03
            },
            "XRP/USDT": {
                "adx_trend": 22,
                "adx_oscillation": 18,
                "stoploss_trend": -0.06,
                "stoploss_oscillation": -0.03,
                "stoploss_mid": -0.045,
                "roi_trend": 0.045,
                "roi_oscillation": 0.02,
                "roi_mid": 0.03
            }
        }

        # Regime-dependent stop-loss and ROI targets (vectorised; pairs without
        # thresholds keep 0.0). Wider stops and higher ROI in trending markets.
        dataframe["&-stoploss_pred"] = 0.0
        dataframe["&-roi_0_pred"] = 0.0
        thresholds = pair_thresholds.get(pair, {})
        if thresholds:
            adx = dataframe["adx"]
            # Risk control: stop-loss never looser than -10%, ROI never above 10%.
            dataframe["&-stoploss_pred"] = np.maximum(
                self._regime_values(adx, thresholds, "stoploss"), -0.10)
            dataframe["&-roi_0_pred"] = np.minimum(
                self._regime_values(adx, thresholds, "roi"), 0.10)

        # RSI target: 20-candle rolling mean of RSI(14).
        dataframe["&-buy_rsi_pred"] = ta.RSI(close, timeperiod=14).rolling(20).mean()

        # Data cleaning.
        # NOTE(review): forward-filling also fills the trailing NaN labels of
        # &-s_close / &-roi_* (no future candles exist yet) with the last known
        # value — confirm this is intended for training.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
        dataframe = dataframe.ffill().fillna(0)

        # Validate that every target exists.
        required_targets = ["&-s_close", "&-roi_0", "&-buy_rsi_pred", "&-stoploss_pred", "&-roi_0_pred"]
        missing_targets = [col for col in required_targets if col not in dataframe.columns]
        if missing_targets:
            logger.error(f"缺少目标列:{missing_targets}")
            raise ValueError(f"目标初始化失败:{missing_targets}")
        logger.debug(f"目标初始化完成。DataFrame 形状:{dataframe.shape}")
        return dataframe

    def populate_indicators(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Run FreqAI to add features/predictions, then add the ATR column.

        :return: the dataframe with FreqAI outputs and ``ATR_<atr_period>``.
        :raises Exception: FreqAI failures are logged and re-raised. (Previously
            a ``return`` inside ``finally`` discarded the re-raised exception.)
        """
        try:
            logger.debug(f"FreqAI 对象:{type(self.freqai)}")
            dataframe = self.freqai.start(dataframe, metadata, self)
            # Validate data completeness.
            if dataframe["close"].isna().any() or dataframe["volume"].isna().any():
                logger.warning("检测到 OKX 数据缺失,使用前向填充")
                dataframe = dataframe.ffill().fillna(0)
            # Prediction statistics.
            if "&-s_close" in dataframe.columns:
                logger.debug(f"预测统计:均值={dataframe['&-s_close'].mean():.4f}, "
                             f"方差={dataframe['&-s_close'].var():.4f}")
            logger.debug(f"生成的列:{list(dataframe.columns)}")
        except Exception as e:
            logger.error(f"FreqAI start 失败:{str(e)}")
            raise
        finally:
            logger.debug("populate_indicators 完成")

        # ATR used by the exit logic and custom_stoploss (same column name).
        dataframe[self._atr_column()] = ta.ATR(
            dataframe['high'], dataframe['low'], dataframe['close'],
            timeperiod=self.atr_period.value)
        return dataframe

    def populate_entry_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Generate long-only entry signals from the regression predictions.

        A long entry requires all of the following:

        - a predicted rise above 1%;
        - ``do_predict == 1``;
        - the 14-period RSI feature below the predicted RSI threshold.

        Also writes ``entry_price``, the open of the latest signal candle,
        forward-filled.
        """
        # NOTE(review): FreqAI may suffix expanded feature names (timeframe/pair),
        # in which case this bare column is absent and defaults to 0.0.
        if "%-%-rsi-14" not in dataframe.columns:
            dataframe["%-%-rsi-14"] = 0.0
        dataframe.loc[
            (
                (dataframe["&-s_close"] > 0.01) &  # predicted rise > 1%
                (dataframe["do_predict"] == 1) &  # prediction deemed reliable
                (dataframe["%-%-rsi-14"] < dataframe["&-buy_rsi_pred"])  # RSI below dynamic threshold
            ),
            "enter_long"] = 1
        # Entry price reference for the stop-loss logic.
        dataframe['entry_price'] = dataframe['open'].where(dataframe['enter_long'] == 1).ffill()
        logger.debug(f"生成 {dataframe['enter_long'].sum()} 个做多信号")
        return dataframe

    def _dynamic_stop_loss(self, dataframe: DataFrame, metadata: dict, atr_col: str = 'ATR_14', multiplier: float = 2.0) -> DataFrame:
        """
        Apply an ATR-based stop line below the entry price.

        :param dataframe: dataframe with ``open``, ``close``, ``enter_long`` and ``atr_col``.
        :param metadata: strategy metadata (unused).
        :param atr_col: name of the ATR column to use.
        :param multiplier: ATR multiple subtracted from the entry price.
        :return: the dataframe with ``entry_price``, ``stop_loss_line`` and
            ``exit_long`` (1 where ``close`` falls below the stop line).
        """
        dataframe['entry_price'] = dataframe['open'].where(dataframe['enter_long'] == 1).ffill()
        dataframe['stop_loss_line'] = dataframe['entry_price'] - dataframe[atr_col] * multiplier
        dataframe.loc[
            (dataframe['close'] < dataframe['stop_loss_line']),
            'exit_long'
        ] = 1
        return dataframe

    def populate_exit_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Exit when price closes below ``entry_price - ATR * atr_multiplier``.

        Uses the ATR column produced by populate_indicators() for the configured
        ``atr_period`` and the hyperoptable ``atr_multiplier``. The defaults
        (14, 2.0) reproduce the former hard-coded ``ATR_14 * 2``.
        """
        atr_col = self._atr_column()
        # Ensure the ATR column exists.
        if atr_col not in dataframe.columns:
            dataframe[atr_col] = 0.0
        dataframe = self._dynamic_stop_loss(
            dataframe, metadata, atr_col=atr_col, multiplier=self.atr_multiplier.value)
        # Report the latest ATR and stop line (skipped when there is no data).
        if self.dp is not None and not dataframe.empty:
            self.dp.send_msg(f"ATR: {dataframe[atr_col].iloc[-1]:.5f}, Stop Loss Line: {dataframe['stop_loss_line'].iloc[-1]:.5f}")
        return dataframe

    def custom_stake_amount(self, pair: str, current_time: 'datetime', current_rate: float,
                            proposed_stake: float, min_stake: float, max_stake: float,
                            entry_tag: Optional[str], **kwargs) -> float:
        """
        Dynamic stake: each trade uses 2% of the available balance.

        The result is clamped to ``[min_stake, max_stake]``. Freqtrade may pass
        ``min_stake=None``, in which case only the upper bound applies.
        """
        balance = self.wallets.get_available_stake_amount()
        stake = balance * 0.02
        if min_stake is not None:
            stake = max(stake, min_stake)
        return min(stake, max_stake)

    def custom_stoploss(self, pair: str, trade: 'Trade', current_time: datetime,
                        current_rate: float, current_profit: float = 0.0,
                        after_fill: bool = False, **kwargs) -> Optional[float]:
        """
        ATR trailing stop: 1.5 x ATR below the current rate, for 'long' trades.

        Freqtrade calls this hook with ``current_profit=`` by keyword. The
        legacy ``profit_percent=`` keyword is still absorbed by ``**kwargs``.

        :return: stop distance relative to ``current_rate`` (negative ratio),
            or ``None`` to keep the current stop when no usable ATR is available.
        """
        # NOTE(review): populate_entry_trend never sets an enter_tag, and
        # Freqtrade only calls this hook when use_custom_stoploss is True — so as
        # written this stop is inactive.
        if trade.enter_tag != 'long' or current_rate <= 0:
            return None
        # The analyzed dataframe carries the ATR column; the raw pair dataframe
        # does not.
        dataframe, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        atr_col = self._atr_column()
        if dataframe.empty or atr_col not in dataframe.columns:
            return None
        atr_value = dataframe[atr_col].iloc[-1]
        if pd.isna(atr_value):
            return None
        trailing_stop = current_rate - atr_value * 1.5
        return trailing_stop / current_rate - 1  # relative ratio

    def leverage(self, pair: str, current_time: 'datetime', current_rate: float,
                 proposed_leverage: float, max_leverage: float, side: str,
                 **kwargs) -> float:
        """Disable leverage (always 1.0); the strategy is long-only."""
        return 1.0

    def confirm_trade_entry(self, pair: str, order_type: str, amount: float, rate: float,
                            time_in_force: str, current_time: 'datetime', **kwargs) -> bool:
        """Confirm an entry only if check_data_freshness() accepts the pair's data."""
        if not self.check_data_freshness(pair, current_time):
            logger.warning(f"{pair} 的 OKX 数据过期,跳过交易")
            return False
        return True

    def check_data_freshness(self, pair: str, current_time: 'datetime') -> bool:
        """
        Simplified data-freshness check that calls no external API.

        Always returns True (data is assumed fresh; intended for testing).
        """
        return True

    def fit(self, data_dictionary: Dict, metadata: Dict, **kwargs) -> None:
        """
        Train the regression model through FreqAI and log its performance.

        :param data_dictionary: FreqAI split data. ``train_features`` and
            ``train_labels`` are read; ``test_features`` and ``test_labels`` are
            read when a non-empty test split is present.
        :param metadata: forwarded unchanged to ``self.freqai.fit``.
        :raises Exception: any failure is logged and re-raised.
        """
        try:
            # Initialise a model if none exists yet.
            if getattr(self, 'model', None) is None:
                model_params = self.freqai_config["model_training_parameters"]
                self.model = XGBRegressor(**model_params)
                logger.debug("初始化新的 XGBoost 回归模型")
            # self.model is not passed to FreqAI, so it is never fitted there.
            # Use the model FreqAI returns (presumably the trained one, per the
            # IFreqaiModel.fit contract) so predict() below runs on a fitted model.
            trained_model = self.freqai.fit(data_dictionary, metadata, **kwargs)
            if trained_model is not None:
                self.model = trained_model
            # Training-set performance.
            train_data = data_dictionary["train_features"]
            train_labels = data_dictionary["train_labels"]
            train_predictions = self.model.predict(train_data)
            train_mse = mean_squared_error(train_labels, train_predictions)
            logger.info(f"训练集 MSE:{train_mse:.6f}")
            # Test-set performance, only when a non-empty test split exists.
            test_data = data_dictionary.get("test_features")
            if test_data is not None and len(test_data) > 0:
                test_labels = data_dictionary["test_labels"]
                test_predictions = self.model.predict(test_data)
                test_mse = mean_squared_error(test_labels, test_predictions)
                logger.info(f"测试集 MSE:{test_mse:.6f}")
            # Feature importances (fall back to positional names for arrays).
            if hasattr(self.model, 'feature_importances_'):
                importance = self.model.feature_importances_
                names = getattr(train_data, "columns", range(len(importance)))
                logger.debug(f"特征重要性:{dict(zip(names, importance))}")
        except Exception as e:
            logger.error(f"FreqAI fit 失败:{str(e)}")
            raise