# Source listing metadata (from the original file viewer): 314 lines, 12 KiB, Python.
import logging
from datetime import datetime
from typing import Dict, List, Optional

import ccxt
import numpy as np
import pandas as pd
import talib as ta
from freqtrade.strategy import IStrategy
from pandas import DataFrame
from sklearn.metrics import mean_squared_error
from xgboost import XGBRegressor


# Module-level logger used by the strategy defined below.
logger: logging.Logger = logging.getLogger(__name__)


class OKXRegressionStrategy(IStrategy):
    """
    Long-only FreqAI strategy for OKX that trades on a regression model's predictions.

    - Candles come from OKX; an order-book imbalance feature is fetched directly
      through CCXT (only in live/dry-run mode).
    - An XGBoost regressor is configured to predict future price changes.
    - Only long (buy) signals are generated; shorting is never used.
    - Targets Freqtrade 2025.3 and subclasses IStrategy.
    """

    # Candles needed before every indicator below yields a value: MACD(12, 26, 9)
    # and the rolling(20) mean of RSI(14) both first become valid at the 34th
    # candle, so the previous value of 20 was too small.
    # NOTE(review): RSI/BBANDS/SMA in feature_engineering_expand_all use FreqAI's
    # `period` values from config; raise this if those periods are larger.
    startup_candle_count: int = 40

    # Strategy settings (prefer configuring these in config.json).
    trailing_stop = True
    trailing_stop_positive = 0.01
    max_open_trades = 3
    # NOTE(review): Freqtrade normally reads stake_amount from config.json
    # (a number or "unlimited"), so this value is presumably ignored; per-trade
    # sizing is done by custom_stake_amount. Confirm before relying on it.
    stake_amount = 'dynamic'

    # FreqAI defaults.
    # NOTE(review): FreqAI itself reads the "freqai" section of config.json; the
    # _freqai_section helper prefers that section and falls back to this dict.
    freqai_config = {
        "enabled": True,
        "identifier": "okx_regression_v1",
        "model_training_parameters": {
            "n_estimators": 100,
            "learning_rate": 0.05,
            "max_depth": 6
        },
        "feature_parameters": {
            "include_timeframes": ["5m", "15m", "1h"],
            "include_corr_pairlist": ["BTC/USDT", "ETH/USDT"],
            "label_period_candles": 12,
            "include_shifted_candles": 2,  # add historical shifted-candle features
            "principal_component_analysis": True  # enable PCA
        },
        "data_split_parameters": {
            "test_size": 0.2,
            "shuffle": False
        },
        "train_period_days": 90,
        "backtest_period_days": 30,
        "purge_old_models": True  # remove old models
    }

    def __init__(self, config: Dict) -> None:
        """
        Initialise the strategy.

        :param config: Freqtrade configuration, forwarded to IStrategy.
        """
        super().__init__(config)
        # Kept for backward compatibility only; features are no longer cached
        # (see feature_engineering_expand_all).
        self.feature_cache: Dict[str, DataFrame] = {}
        # CCXT client for OKX, created on the first order-book request and reused.
        self._okx_exchange = None

    def _freqai_section(self, section: str) -> Dict:
        """
        Return one section of the FreqAI settings.

        The "freqai" block of the runtime config wins when it defines the section;
        otherwise the class-level ``freqai_config`` default is used.

        :param section: section key, e.g. "feature_parameters".
        :return: the section dict (empty if neither source defines it).
        """
        runtime = (getattr(self, "config", None) or {}).get("freqai") or {}
        return runtime.get(section) or self.freqai_config.get(section, {})

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int,
                                       metadata: Dict, **kwargs) -> DataFrame:
        """
        Generate features for each timeframe and correlated pair.

        Features are recomputed on every call. The previous cache keyed on
        ``pair_period`` ignored the timeframe (so one timeframe's features were
        returned for another) and never expired (so live candles went stale).

        :param dataframe: OHLCV candles; "close" and "volume" are read.
        :param period: period for RSI, Bollinger-band width and volume SMA.
        :param metadata: pair metadata; "pair" is used for the order-book lookup.
        :return: the dataframe with "%-%-"-prefixed feature columns added.
        """
        pair = metadata.get("pair", "unknown")
        close = dataframe["close"]

        # RSI
        dataframe[f"%-%-rsi-{period}"] = ta.RSI(close, timeperiod=period)

        # MACD (fixed 12/26/9 parameters, independent of `period`)
        macd, macdsignal, _ = ta.MACD(close, fastperiod=12, slowperiod=26, signalperiod=9)
        dataframe[f"%-%-macd-{period}"] = macd
        dataframe[f"%-%-macdsignal-{period}"] = macdsignal

        # Bollinger band width, normalised by the middle band
        upper, middle, lower = ta.BBANDS(close, timeperiod=period)
        dataframe[f"%-%-bb_width-{period}"] = (upper - lower) / middle

        # Volume moving average
        dataframe[f"%-%-volume_ma-{period}"] = ta.SMA(dataframe["volume"], timeperiod=period)

        # Order-book imbalance (0.0 outside live/dry-run, see _order_book_imbalance)
        dataframe[f"%-%-order_book_imbalance-{period}"] = self._order_book_imbalance(pair)

        # Map +/-inf (e.g. a zero middle band) to NaN and forward-fill gaps.
        # Warm-up NaNs are deliberately NOT zero-filled any more: zeros looked like
        # real indicator values. FreqAI presumably drops/flags NaN rows itself -
        # confirm for your version.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan).ffill()

        logger.debug(f"周期 {period} 特征:{list(dataframe.filter(like='%-%-').columns)}")
        return dataframe

    def feature_engineering_standard(self, dataframe: DataFrame, **kwargs) -> DataFrame:
        """
        Add global features computed on the base timeframe.

        :param dataframe: OHLCV candles; "close" and the candle timestamps are read.
        :return: the dataframe with "%-price_change" and "%-hour_of_day" added.
        """
        # Price rate of change
        dataframe["%-price_change"] = dataframe["close"].pct_change()

        # Hour of day scaled to [0, 1). Prefer the "date" column when present (as
        # in Freqtrade's OHLCV frames): converting a plain integer index with
        # pd.DatetimeIndex yields 1970-01-01 timestamps, so the hour was always 0.
        # The index is no longer replaced, so the caller gets its own index back.
        if "date" in dataframe.columns:
            hours = pd.to_datetime(dataframe["date"]).dt.hour
        else:
            hours = pd.DatetimeIndex(dataframe.index).hour
        dataframe["%-hour_of_day"] = hours / 24.0

        # Map +/-inf to NaN and forward-fill; leading NaNs are left in place.
        dataframe = dataframe.replace([np.inf, -np.inf], np.nan).ffill()

        # Standard features use the single "%-" prefix, so filter on that.
        logger.debug(f"全局特征:{list(dataframe.filter(like='%-').columns)}")
        return dataframe

    def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs) -> DataFrame:
        """
        Define the regression targets ("&-" columns) for FreqAI.

        :param dataframe: candles with a "close" column.
        :param metadata: pair metadata (not used here).
        :return: the dataframe with all required target columns.
        :raises ValueError: if a required target column is missing.
        """
        label_period = self._freqai_section("feature_parameters").get(
            "label_period_candles",
            self.freqai_config["feature_parameters"]["label_period_candles"],
        )
        close = dataframe["close"]

        # Main target: relative price change `label_period` candles ahead
        dataframe["&-s_close"] = (close.shift(-label_period) - close) / close

        # ROI targets: relative change 0/15/30 minutes ahead.
        # NOTE(review): minutes // 5 assumes a 5m base timeframe.
        for minutes in (0, 15, 30):
            candles = minutes // 5
            if candles > 0:
                dataframe[f"&-roi_{minutes}"] = (close.shift(-candles) - close) / close
            else:
                dataframe[f"&-roi_{minutes}"] = 0.0

        # Dynamic thresholds
        dataframe["&-buy_rsi_pred"] = ta.RSI(close, timeperiod=14).rolling(20).mean()
        dataframe["&-stoploss_pred"] = -0.05
        dataframe["&-roi_0_pred"] = 0.03

        # Only map +/-inf to NaN on the target columns. The last rows have no
        # future close yet; the previous ffill().fillna(0) copied older labels into
        # them (fabricated training data) via the pandas-deprecated
        # fillna(method=...). Leaving NaN lets FreqAI skip those rows (presumably
        # it drops NaN-labelled rows when training - confirm).
        target_cols = [col for col in dataframe.columns if col.startswith("&")]
        dataframe[target_cols] = dataframe[target_cols].replace([np.inf, -np.inf], np.nan)

        # Validate targets
        required_targets = ["&-s_close", "&-roi_0", "&-buy_rsi_pred", "&-stoploss_pred", "&-roi_0_pred"]
        missing_targets = [col for col in required_targets if col not in dataframe.columns]
        if missing_targets:
            logger.error(f"缺少目标列:{missing_targets}")
            raise ValueError(f"目标初始化失败:{missing_targets}")

        logger.debug(f"目标初始化完成。DataFrame 形状:{dataframe.shape}")
        return dataframe

    def populate_indicators(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Generate indicators and predictions using FreqAI.

        :param dataframe: candles for metadata["pair"].
        :param metadata: pair metadata forwarded to ``self.freqai.start``.
        :return: the dataframe returned by FreqAI.
        :raises: any error from ``self.freqai.start`` (logged, then re-raised), or
            ValueError if it returns None.
        """
        try:
            logger.debug(f"FreqAI 对象:{type(self.freqai)}")
            dataframe = self.freqai.start(dataframe, metadata, self)

            # The old trailing None guard was unreachable (try always returned or
            # raised). Fail loudly instead: an empty frame would only break later in
            # populate_entry_trend, which indexes "&-s_close".
            if dataframe is None:
                raise ValueError("FreqAI start 返回 None")

            # Data-integrity check
            if dataframe["close"].isna().any() or dataframe["volume"].isna().any():
                logger.warning("检测到 OKX 数据缺失,使用前向填充")
                dataframe = dataframe.ffill().fillna(0)

            # Prediction statistics
            if "&-s_close" in dataframe.columns:
                logger.debug(f"预测统计:均值={dataframe['&-s_close'].mean():.4f}, "
                             f"方差={dataframe['&-s_close'].var():.4f}")

            logger.debug(f"生成的列:{list(dataframe.columns)}")
            return dataframe
        except Exception as e:
            logger.error(f"FreqAI start 失败:{str(e)}")
            raise
        finally:
            logger.debug("populate_indicators 完成")

    def populate_entry_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Generate long-only entry signals from the regression predictions.

        :param dataframe: output of populate_indicators ("&-s_close", "do_predict"
            and "&-buy_rsi_pred" are read).
        :param metadata: pair metadata (not used here).
        :return: the dataframe with "enter_long" set to 1 on entry candles.
        """
        # A bare "%-%-rsi-14" column is often absent (FreqAI feature names depend
        # on the configured periods/timeframes). The old 0.0 fallback made the RSI
        # filter below always pass, so compute RSI(14) directly instead.
        if "%-%-rsi-14" not in dataframe.columns:
            dataframe["%-%-rsi-14"] = ta.RSI(dataframe["close"], timeperiod=14)

        enter_mask = (
            (dataframe["&-s_close"] > 0.01)                                  # predicted rise > 1%
            & (dataframe["do_predict"] == 1)                                 # prediction flagged usable
            & (dataframe["%-%-rsi-14"] < dataframe["&-buy_rsi_pred"])        # RSI below dynamic threshold
        )
        dataframe.loc[enter_mask, "enter_long"] = 1
        logger.debug(f"生成 {dataframe['enter_long'].sum()} 个做多信号")
        return dataframe

    def populate_exit_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
        """
        Generate exit signals from the ROI / stop-loss thresholds.

        NOTE(review): both conditions compare the latest single-candle return, not
        the open trade's profit - confirm that this is intended.

        :param dataframe: output of populate_indicators ("&-roi_0_pred" and
            "&-stoploss_pred" are read).
        :param metadata: pair metadata (not used here).
        :return: the dataframe with "exit_long" set to 1 on exit candles.
        """
        candle_return = dataframe["close"].pct_change()
        exit_mask = (
            (candle_return >= dataframe["&-roi_0_pred"])       # ROI reached
            | (candle_return <= dataframe["&-stoploss_pred"])  # stop-loss reached
        )
        dataframe.loc[exit_mask, "exit_long"] = 1
        logger.debug(f"生成 {dataframe['exit_long'].sum()} 个退出信号")
        return dataframe

    def custom_stake_amount(self, pair: str, current_time: 'datetime', current_rate: float,
                            proposed_stake: float, min_stake: Optional[float], max_stake: float,
                            entry_tag: Optional[str], **kwargs) -> float:
        """
        Dynamic sizing: stake 2% of the available stake balance per trade.

        :param proposed_stake: Freqtrade's default stake, returned unchanged when no
            wallet is attached.
        :param min_stake: exchange minimum stake, or None when unknown.
        :param max_stake: upper bound for the returned stake.
        :return: the stake, clamped to [min_stake, max_stake].
        """
        wallets = getattr(self, "wallets", None)
        if wallets is None:
            return proposed_stake

        stake = wallets.get_available_stake_amount() * 0.02
        # min_stake may be None; max(stake, None) would raise TypeError.
        if min_stake is not None:
            stake = max(stake, min_stake)
        return min(stake, max_stake)

    def leverage(self, pair: str, current_time: 'datetime', current_rate: float,
                 proposed_leverage: float, max_leverage: float, side: str,
                 **kwargs) -> float:
        """
        Disable leverage: always trade at 1x (the strategy is long-only).

        :return: 1.0
        """
        return 1.0

    def confirm_trade_entry(self, pair: str, order_type: str, amount: float, rate: float,
                            time_in_force: str, current_time: 'datetime', **kwargs) -> bool:
        """
        Confirm an entry only if the OKX data for the pair is fresh.

        :return: False (skip the trade) when check_data_freshness fails, else True.
        """
        if not self.check_data_freshness(pair, current_time):
            logger.warning(f"{pair} 的 OKX 数据过期,跳过交易")
            return False
        return True

    def check_data_freshness(self, pair: str, current_time: 'datetime') -> bool:
        """
        Simplified data-freshness check that calls no external API.

        :return: always True (placeholder used for testing).
        """
        # Assume the data is always fresh (for testing).
        return True

    def _order_book_imbalance(self, pair: str) -> float:
        """
        Return (bids - asks) / (bids + asks) of the current OKX order book, or 0.0.

        The order book is a single live snapshot. In backtesting/hyperopt no
        snapshot matches the historical candles, so the feature is 0.0 there and no
        network request is made; only live and dry-run modes fetch it.

        :param pair: trading pair, e.g. "BTC/USDT".
        :return: the imbalance, or 0.0 when not live or on any failure.
        """
        runmode = getattr(getattr(getattr(self, "dp", None), "runmode", None), "value", None)
        if runmode not in ("live", "dry_run"):
            return 0.0
        try:
            order_book = self.fetch_okx_order_book(pair)
            return (order_book["bids"] - order_book["asks"]) / (
                order_book["bids"] + order_book["asks"] + 1e-10
            )
        except Exception as e:
            logger.warning(f"获取 {pair} 订单簿失败:{str(e)}")
            return 0.0

    def fetch_okx_order_book(self, pair: str, limit: int = 10) -> Dict:
        """
        Fetch the OKX order book and sum the amounts on each side.

        :param pair: trading pair, e.g. "BTC/USDT".
        :param limit: number of price levels requested per side.
        :return: {"bids": total bid amount, "asks": total ask amount}; both 0 on error.
        """
        try:
            exchange = getattr(self, "_okx_exchange", None)
            if exchange is None:
                # Reuse one client: a fresh ccxt instance must load markets again
                # before its first request.
                exchange = ccxt.okx(self.config["exchange"]["ccxt_config"])
                self._okx_exchange = exchange
            order_book = exchange.fetch_order_book(pair, limit=limit)
            bids = sum(level[1] for level in order_book["bids"])
            asks = sum(level[1] for level in order_book["asks"])
            return {"bids": bids, "asks": asks}
        except Exception as e:
            logger.error(f"获取 {pair} 订单簿失败:{str(e)}")
            return {"bids": 0, "asks": 0}

    def fit(self, data_dictionary: Dict, metadata: Dict, **kwargs) -> None:
        """
        Train the regression model via FreqAI and log its train/test MSE.

        NOTE(review): IStrategy does not appear to define a ``fit`` hook; confirm
        this method is actually invoked somewhere.

        :param data_dictionary: data split; "train_features"/"train_labels" are
            required, "test_features"/"test_labels" are used when present and non-empty.
        :param metadata: forwarded to ``self.freqai.fit``.
        :raises: any training/evaluation error, after logging it.
        """
        try:
            # FreqAI training. Evaluate the model it returns when it returns one;
            # the original always evaluated a freshly built, never-fitted
            # XGBRegressor, so predict() failed.
            trained = self.freqai.fit(data_dictionary, metadata, **kwargs)

            train_data = data_dictionary["train_features"]
            train_labels = data_dictionary["train_labels"]

            if trained is not None and hasattr(trained, "predict"):
                self.model = trained
            else:
                if getattr(self, "model", None) is None:
                    model_params = self._freqai_section("model_training_parameters")
                    self.model = XGBRegressor(**model_params)
                    logger.debug("初始化新的 XGBoost 回归模型")
                self.model.fit(train_data, train_labels)

            # Training-set performance
            train_predictions = self.model.predict(train_data)
            train_mse = mean_squared_error(train_labels, train_predictions)
            logger.info(f"训练集 MSE:{train_mse:.6f}")

            # Test-set performance (skipped when the split is missing or empty,
            # e.g. test_size=0, where mean_squared_error would fail)
            test_data = data_dictionary.get("test_features")
            test_labels = data_dictionary.get("test_labels")
            if test_data is not None and test_labels is not None and len(test_data) > 0:
                test_predictions = self.model.predict(test_data)
                test_mse = mean_squared_error(test_labels, test_predictions)
                logger.info(f"测试集 MSE:{test_mse:.6f}")

            # Feature importance (column names when train_data is a DataFrame)
            importance = getattr(self.model, "feature_importances_", None)
            if importance is not None:
                names = getattr(train_data, "columns", range(len(importance)))
                logger.debug(f"特征重要性:{dict(zip(names, importance))}")
        except Exception as e:
            logger.error(f"FreqAI fit 失败:{str(e)}")
            raise