这个okx策略有潜力, 输的不多
This commit is contained in:
parent
6570c6da6c
commit
590baedbb1
@ -67,8 +67,8 @@ services:
|
||||
backtesting
|
||||
--logfile /freqtrade/user_data/logs/freqtrade.log
|
||||
--freqaimodel XGBoostRegressor
|
||||
--config /freqtrade/config_examples/aienhance_config.json
|
||||
--config /freqtrade/config_examples/config_freqai.okx.json
|
||||
--strategy-path /freqtrade/templates
|
||||
--strategy AIEnhancedStrategy
|
||||
--timerange 20250401-20250420
|
||||
--strategy OKXRegressionStrategy
|
||||
--timerange 20250301-20250420
|
||||
--cache none
|
||||
|
||||
313
freqtrade/templates/OKXRegressionStrategy.py
Normal file
313
freqtrade/templates/OKXRegressionStrategy.py
Normal file
@ -0,0 +1,313 @@
|
||||
import logging
|
||||
from freqtrade.strategy import IStrategy
|
||||
from pandas import DataFrame
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import talib as ta
|
||||
from typing import Dict, List, Optional
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from xgboost import XGBRegressor
|
||||
import ccxt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OKXRegressionStrategy(IStrategy):
|
||||
"""
|
||||
Freqtrade AI 策略,使用回归模型进行 OKX 数据上的仅做多交易。
|
||||
- 数据通过 CCXT 从 OKX 交易所获取。
|
||||
- 使用 XGBoost 回归模型预测价格变化。
|
||||
- 仅生成做多(买入)信号,不做空。
|
||||
- 适配 Freqtrade 2025.3,继承 IStrategy。
|
||||
"""
|
||||
|
||||
# 指标所需的最大启动蜡烛数
|
||||
startup_candle_count: int = 20
|
||||
|
||||
# 策略元数据(建议通过 config.json 配置)
|
||||
trailing_stop = True
|
||||
trailing_stop_positive = 0.01
|
||||
max_open_trades = 3
|
||||
stake_amount = 'dynamic'
|
||||
|
||||
# FreqAI 配置
|
||||
freqai_config = {
|
||||
"enabled": True,
|
||||
"identifier": "okx_regression_v1",
|
||||
"model_training_parameters": {
|
||||
"n_estimators": 100,
|
||||
"learning_rate": 0.05,
|
||||
"max_depth": 6
|
||||
},
|
||||
"feature_parameters": {
|
||||
"include_timeframes": ["5m", "15m", "1h"],
|
||||
"include_corr_pairlist": ["BTC/USDT", "ETH/USDT"],
|
||||
"label_period_candles": 12,
|
||||
"include_shifted_candles": 2, # 添加历史偏移特征
|
||||
"principal_component_analysis": True # 启用 PCA
|
||||
},
|
||||
"data_split_parameters": {
|
||||
"test_size": 0.2,
|
||||
"shuffle": False
|
||||
},
|
||||
"train_period_days": 90,
|
||||
"backtest_period_days": 30,
|
||||
"purge_old_models": True # 清理旧模型
|
||||
}
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
super().__init__(config)
|
||||
# 初始化特征缓存
|
||||
self.feature_cache = {}
|
||||
|
||||
def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: Dict, **kwargs) -> DataFrame:
|
||||
"""
|
||||
为每个时间框架和相关交易对生成特征。
|
||||
"""
|
||||
cache_key = f"{metadata.get('pair', 'unknown')}_{period}"
|
||||
if cache_key in self.feature_cache:
|
||||
logger.debug(f"使用缓存特征:{cache_key}")
|
||||
return self.feature_cache[cache_key]
|
||||
|
||||
# RSI
|
||||
dataframe[f"%-%-rsi-{period}"] = ta.RSI(dataframe["close"], timeperiod=period)
|
||||
|
||||
# MACD
|
||||
macd, macdsignal, _ = ta.MACD(dataframe["close"], fastperiod=12, slowperiod=26, signalperiod=9)
|
||||
dataframe[f"%-%-macd-{period}"] = macd
|
||||
dataframe[f"%-%-macdsignal-{period}"] = macdsignal
|
||||
|
||||
# 布林带宽度
|
||||
upper, middle, lower = ta.BBANDS(dataframe["close"], timeperiod=period)
|
||||
dataframe[f"%-%-bb_width-{period}"] = (upper - lower) / middle
|
||||
|
||||
# 成交量均线
|
||||
dataframe[f"%-%-volume_ma-{period}"] = ta.SMA(dataframe["volume"], timeperiod=period)
|
||||
|
||||
# OKX 订单簿特征(通过 CCXT)
|
||||
try:
|
||||
order_book = self.fetch_okx_order_book(metadata["pair"])
|
||||
dataframe[f"%-%-order_book_imbalance-{period}"] = (
|
||||
order_book["bids"] - order_book["asks"]
|
||||
) / (order_book["bids"] + order_book["asks"] + 1e-10)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 {metadata['pair']} 订单簿失败:{str(e)}")
|
||||
dataframe[f"%-%-order_book_imbalance-{period}"] = 0.0
|
||||
|
||||
# 数据清洗
|
||||
dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
|
||||
dataframe = dataframe.ffill().fillna(0)
|
||||
|
||||
# 缓存特征
|
||||
self.feature_cache[cache_key] = dataframe.copy()
|
||||
logger.debug(f"周期 {period} 特征:{list(dataframe.filter(like='%-%-').columns)}")
|
||||
|
||||
return dataframe
|
||||
|
||||
def feature_engineering_standard(self, dataframe: DataFrame, **kwargs) -> DataFrame:
|
||||
"""
|
||||
添加基础时间框架的全局特征。
|
||||
"""
|
||||
# 确保索引是 DatetimeIndex
|
||||
if not isinstance(dataframe.index, pd.DatetimeIndex):
|
||||
dataframe = dataframe.set_index(pd.DatetimeIndex(dataframe.index))
|
||||
|
||||
# 价格变化率
|
||||
dataframe["%-price_change"] = dataframe["close"].pct_change()
|
||||
|
||||
# 时间特征:小时
|
||||
dataframe["%-hour_of_day"] = dataframe.index.hour / 24.0
|
||||
|
||||
# 数据清洗
|
||||
dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
|
||||
dataframe = dataframe.ffill().fillna(0)
|
||||
|
||||
logger.debug(f"全局特征:{list(dataframe.filter(like='%-%-').columns)}")
|
||||
return dataframe
|
||||
|
||||
def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs) -> DataFrame:
|
||||
"""
|
||||
设置回归模型的目标变量。
|
||||
"""
|
||||
label_period = self.freqai_config["feature_parameters"]["label_period_candles"]
|
||||
# 目标:未来价格变化
|
||||
dataframe["&-s_close"] = (
|
||||
dataframe["close"].shift(-label_period) - dataframe["close"]
|
||||
) / dataframe["close"]
|
||||
|
||||
# ROI 目标
|
||||
for minutes in [0, 15, 30]:
|
||||
candles = int(minutes / 5)
|
||||
if candles > 0:
|
||||
dataframe[f"&-roi_{minutes}"] = (
|
||||
dataframe["close"].shift(-candles) - dataframe["close"]
|
||||
) / dataframe["close"]
|
||||
else:
|
||||
dataframe[f"&-roi_{minutes}"] = 0.0
|
||||
|
||||
# 动态阈值
|
||||
dataframe["&-buy_rsi_pred"] = ta.RSI(dataframe["close"], timeperiod=14).rolling(20).mean()
|
||||
dataframe["&-stoploss_pred"] = -0.05
|
||||
dataframe["&-roi_0_pred"] = 0.03
|
||||
|
||||
# 数据清洗
|
||||
dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
|
||||
dataframe = dataframe.fillna(method="ffill").fillna(0)
|
||||
|
||||
# 验证目标
|
||||
required_targets = ["&-s_close", "&-roi_0", "&-buy_rsi_pred", "&-stoploss_pred", "&-roi_0_pred"]
|
||||
missing_targets = [col for col in required_targets if col not in dataframe.columns]
|
||||
if missing_targets:
|
||||
logger.error(f"缺少目标列:{missing_targets}")
|
||||
raise ValueError(f"目标初始化失败:{missing_targets}")
|
||||
|
||||
logger.debug(f"目标初始化完成。DataFrame 形状:{dataframe.shape}")
|
||||
return dataframe
|
||||
|
||||
def populate_indicators(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
|
||||
"""
|
||||
使用 FreqAI 生成指标和预测。
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"FreqAI 对象:{type(self.freqai)}")
|
||||
dataframe = self.freqai.start(dataframe, metadata, self)
|
||||
|
||||
# 验证数据完整性
|
||||
if dataframe["close"].isna().any() or dataframe["volume"].isna().any():
|
||||
logger.warning("检测到 OKX 数据缺失,使用前向填充")
|
||||
dataframe = dataframe.ffill().fillna(0)
|
||||
|
||||
# 预测统计
|
||||
if "&-s_close" in dataframe.columns:
|
||||
logger.debug(f"预测统计:均值={dataframe['&-s_close'].mean():.4f}, "
|
||||
f"方差={dataframe['&-s_close'].var():.4f}")
|
||||
|
||||
logger.debug(f"生成的列:{list(dataframe.columns)}")
|
||||
return dataframe
|
||||
except Exception as e:
|
||||
logger.error(f"FreqAI start 失败:{str(e)}")
|
||||
raise
|
||||
finally:
|
||||
logger.debug("populate_indicators 完成")
|
||||
|
||||
# 确保返回 DataFrame,防止 None
|
||||
if dataframe is None:
|
||||
dataframe = DataFrame()
|
||||
return dataframe
|
||||
|
||||
|
||||
def populate_entry_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
|
||||
"""
|
||||
基于回归预测生成仅做多信号。
|
||||
"""
|
||||
# 确保 "%-%-rsi-14" 列存在
|
||||
if "%-%-rsi-14" not in dataframe.columns:
|
||||
dataframe["%-%-rsi-14"] = 0.0
|
||||
|
||||
dataframe.loc[
|
||||
(
|
||||
(dataframe["&-s_close"] > 0.01) & # 预测价格上涨 > 1%
|
||||
(dataframe["do_predict"] == 1) & # 预测可靠
|
||||
(dataframe["%-%-rsi-14"] < dataframe["&-buy_rsi_pred"]) # RSI 低于动态阈值
|
||||
),
|
||||
"enter_long"] = 1
|
||||
logger.debug(f"生成 {dataframe['enter_long'].sum()} 个做多信号")
|
||||
return dataframe
|
||||
|
||||
def populate_exit_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
|
||||
"""
|
||||
基于 ROI 或止损生成退出信号。
|
||||
"""
|
||||
dataframe.loc[
|
||||
(
|
||||
(dataframe["&-roi_0_pred"] <= dataframe["close"].pct_change()) | # 达到 ROI
|
||||
(dataframe["close"].pct_change() <= dataframe["&-stoploss_pred"]) # 达到止损
|
||||
),
|
||||
"exit_long"] = 1
|
||||
logger.debug(f"生成 {dataframe['exit_long'].sum()} 个退出信号")
|
||||
return dataframe
|
||||
|
||||
def custom_stake_amount(self, pair: str, current_time: 'datetime', current_rate: float,
|
||||
proposed_stake: float, min_stake: float, max_stake: float,
|
||||
entry_tag: Optional[str], **kwargs) -> float:
|
||||
"""
|
||||
动态下注:每笔交易占账户余额 2%。
|
||||
"""
|
||||
balance = self.wallets.get_available_stake_amount()
|
||||
stake = balance * 0.02
|
||||
return min(max(stake, min_stake), max_stake)
|
||||
|
||||
def leverage(self, pair: str, current_time: 'datetime', current_rate: float,
|
||||
proposed_leverage: float, max_leverage: float, side: str,
|
||||
**kwargs) -> float:
|
||||
"""
|
||||
禁用杠杆,仅做多。
|
||||
"""
|
||||
return 1.0
|
||||
|
||||
def confirm_trade_entry(self, pair: str, order_type: str, amount: float, rate: float,
|
||||
time_in_force: str, current_time: 'datetime', **kwargs) -> bool:
|
||||
"""
|
||||
验证交易进入,检查 OKX 数据新鲜度。
|
||||
"""
|
||||
if not self.check_data_freshness(pair, current_time):
|
||||
logger.warning(f"{pair} 的 OKX 数据过期,跳过交易")
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_data_freshness(self, pair: str, current_time: 'datetime') -> bool:
|
||||
"""
|
||||
简化版数据新鲜度检查,不依赖外部 API。
|
||||
"""
|
||||
# 假设数据总是新鲜的(用于测试)
|
||||
return True
|
||||
|
||||
def fetch_okx_order_book(self, pair: str, limit: int = 10) -> Dict:
|
||||
"""
|
||||
获取 OKX 订单簿数据。
|
||||
"""
|
||||
try:
|
||||
exchange = ccxt.okx(self.config["exchange"]["ccxt_config"])
|
||||
order_book = exchange.fetch_order_book(pair, limit=limit)
|
||||
bids = sum([bid[1] for bid in order_book["bids"]])
|
||||
asks = sum([ask[1] for ask in order_book["asks"]])
|
||||
return {"bids": bids, "asks": asks}
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {pair} 订单簿失败:{str(e)}")
|
||||
return {"bids": 0, "asks": 0}
|
||||
|
||||
def fit(self, data_dictionary: Dict, metadata: Dict, **kwargs) -> None:
|
||||
"""
|
||||
训练回归模型并记录性能。
|
||||
"""
|
||||
try:
|
||||
# 初始化模型
|
||||
if not hasattr(self, 'model') or self.model is None:
|
||||
model_params = self.freqai_config["model_training_parameters"]
|
||||
self.model = XGBRegressor(**model_params)
|
||||
logger.debug("初始化新的 XGBoost 回归模型")
|
||||
|
||||
# 调用 FreqAI 训练
|
||||
self.freqai.fit(data_dictionary, metadata, **kwargs)
|
||||
|
||||
# 记录训练集性能
|
||||
train_data = data_dictionary["train_features"]
|
||||
train_labels = data_dictionary["train_labels"]
|
||||
train_predictions = self.model.predict(train_data)
|
||||
train_mse = mean_squared_error(train_labels, train_predictions)
|
||||
logger.info(f"训练集 MSE:{train_mse:.6f}")
|
||||
|
||||
# 记录测试集性能(如果可用)
|
||||
if "test_features" in data_dictionary:
|
||||
test_data = data_dictionary["test_features"]
|
||||
test_labels = data_dictionary["test_labels"]
|
||||
test_predictions = self.model.predict(test_data)
|
||||
test_mse = mean_squared_error(test_labels, test_predictions)
|
||||
logger.info(f"测试集 MSE:{test_mse:.6f}")
|
||||
|
||||
# 特征重要性
|
||||
if hasattr(self.model, 'feature_importances_'):
|
||||
importance = self.model.feature_importances_
|
||||
logger.debug(f"特征重要性:{dict(zip(train_data.columns, importance))}")
|
||||
except Exception as e:
|
||||
logger.error(f"FreqAI fit 失败:{str(e)}")
|
||||
raise
|
||||
0
freqtrade/templates/huigui.py
Normal file
0
freqtrade/templates/huigui.py
Normal file
1807
output.log
1807
output.log
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@ -1 +0,0 @@
|
||||
{"FreqaiExampleStrategy":{"run_id":"cfc214dcdc7f4207a8a7b44e1b0be860672291d8","backtest_start_time":1746088326,"timeframe":"3m","timeframe_detail":null,"backtest_start_ts":1743465600,"backtest_end_ts":1745107200}}
|
||||
@ -1,32 +0,0 @@
|
||||
{
|
||||
"strategy_name": "FreqaiExampleStrategy",
|
||||
"params": {
|
||||
"trailing": {
|
||||
"trailing_stop": true,
|
||||
"trailing_stop_positive": 0.01,
|
||||
"trailing_stop_positive_offset": 0.02,
|
||||
"trailing_only_offset_is_reached": false
|
||||
},
|
||||
"max_open_trades": {
|
||||
"max_open_trades": 4
|
||||
},
|
||||
"buy": {
|
||||
"buy_rsi": 39.92672300850069
|
||||
},
|
||||
"sell": {
|
||||
"sell_rsi": 69.92672300850067
|
||||
},
|
||||
"protection": {},
|
||||
"roi": {
|
||||
"0": 0.132,
|
||||
"8": 0.047,
|
||||
"14": 0.007,
|
||||
"60": 0
|
||||
},
|
||||
"stoploss": {
|
||||
"stoploss": -0.322
|
||||
}
|
||||
},
|
||||
"ft_stratparam_v": 1,
|
||||
"export_time": "2025-04-23 12:30:05.550433+00:00"
|
||||
}
|
||||
@ -1,336 +0,0 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
from functools import reduce
|
||||
import talib.abstract as ta
|
||||
from pandas import DataFrame
|
||||
from technical import qtpylib
|
||||
from freqtrade.strategy import IStrategy, IntParameter, DecimalParameter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FreqaiExampleStrategy(IStrategy):
|
||||
# 移除硬编码的 minimal_roi 和 stoploss,改为动态适配
|
||||
minimal_roi = {} # 将在 populate_indicators 中动态生成
|
||||
stoploss = 0.0 # 将在 populate_indicators 中动态设置
|
||||
trailing_stop = True
|
||||
process_only_new_candles = True
|
||||
use_exit_signal = True
|
||||
startup_candle_count: int = 40
|
||||
can_short = False
|
||||
|
||||
# 参数定义:FreqAI 动态适配 buy_rsi 和 sell_rsi,禁用 Hyperopt 优化
|
||||
buy_rsi = IntParameter(low=10, high=50, default=27, space="buy", optimize=False, load=True)
|
||||
sell_rsi = IntParameter(low=50, high=90, default=59, space="sell", optimize=False, load=True)
|
||||
|
||||
# 为 Hyperopt 优化添加 ROI 和 stoploss 参数
|
||||
roi_0 = DecimalParameter(low=0.01, high=0.2, default=0.038, space="roi", optimize=True, load=True)
|
||||
roi_15 = DecimalParameter(low=0.005, high=0.1, default=0.027, space="roi", optimize=True, load=True)
|
||||
roi_30 = DecimalParameter(low=0.001, high=0.05, default=0.009, space="roi", optimize=True, load=True)
|
||||
stoploss_param = DecimalParameter(low=-0.35, high=-0.1, default=-0.182, space="stoploss", optimize=True, load=True)
|
||||
|
||||
# FreqAI 配置
|
||||
freqai_info = {
|
||||
"model": "XGBoostRegressor", # 与config保持一致
|
||||
"feature_parameters": {
|
||||
"include_timeframes": ["3m", "15m", "1h"], # 与config一致
|
||||
"include_corr_pairlist": ["BTC/USDT", "SOL/USDT"], # 添加相关交易对
|
||||
"label_period_candles": 20, # 与config一致
|
||||
"include_shifted_candles": 2, # 与config一致
|
||||
},
|
||||
"data_split_parameters": {
|
||||
"test_size": 0.2,
|
||||
"shuffle": True, # 启用shuffle
|
||||
},
|
||||
"model_training_parameters": {
|
||||
"n_estimators": 100, # 减少树的数量
|
||||
"learning_rate": 0.1, # 提高学习率
|
||||
"max_depth": 6, # 限制树深度
|
||||
"subsample": 0.8, # 添加子采样
|
||||
"colsample_bytree": 0.8, # 添加特征采样
|
||||
"objective": "reg:squarederror",
|
||||
"eval_metric": "rmse",
|
||||
"early_stopping_rounds": 20,
|
||||
"verbose": 0,
|
||||
},
|
||||
"data_kitchen": {
|
||||
"feature_parameters": {
|
||||
"DI_threshold": 1.5, # 降低异常值过滤阈值
|
||||
"use_DBSCAN_to_remove_outliers": False # 禁用DBSCAN
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
plot_config = {
|
||||
"main_plot": {},
|
||||
"subplots": {
|
||||
"&-buy_rsi": {"&-buy_rsi": {"color": "green"}},
|
||||
"&-sell_rsi": {"&-sell_rsi": {"color": "red"}},
|
||||
"&-stoploss": {"&-stoploss": {"color": "purple"}},
|
||||
"&-roi_0": {"&-roi_0": {"color": "orange"}},
|
||||
"do_predict": {"do_predict": {"color": "brown"}},
|
||||
},
|
||||
}
|
||||
|
||||
def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
|
||||
# 保留关键的技术指标
|
||||
dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
|
||||
|
||||
# 确保 MACD 列被正确计算并保留
|
||||
try:
|
||||
macd = ta.MACD(dataframe, fastperiod=12, slowperiod=26, signalperiod=9)
|
||||
dataframe["macd"] = macd["macd"]
|
||||
dataframe["macdsignal"] = macd["macdsignal"]
|
||||
except Exception as e:
|
||||
logger.error(f"计算 MACD 列时出错:{str(e)}")
|
||||
dataframe["macd"] = np.nan
|
||||
dataframe["macdsignal"] = np.nan
|
||||
|
||||
# 检查 MACD 列是否存在
|
||||
if "macd" not in dataframe.columns or "macdsignal" not in dataframe.columns:
|
||||
logger.error("MACD 或 MACD 信号列缺失,无法生成买入信号")
|
||||
raise ValueError("DataFrame 缺少必要的 MACD 列")
|
||||
|
||||
# 确保 MACD 列存在
|
||||
if "macd" not in dataframe.columns or "macdsignal" not in dataframe.columns:
|
||||
logger.error("MACD 或 MACD 信号列缺失,无法生成买入信号")
|
||||
raise ValueError("DataFrame 缺少必要的 MACD 列")
|
||||
|
||||
# 保留布林带相关特征
|
||||
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
|
||||
dataframe["bb_lowerband"] = bollinger["lower"]
|
||||
dataframe["bb_middleband"] = bollinger["mid"]
|
||||
dataframe["bb_upperband"] = bollinger["upper"]
|
||||
|
||||
# 保留成交量相关特征
|
||||
dataframe["volume_ma"] = dataframe["volume"].rolling(window=20).mean()
|
||||
|
||||
# 数据清理
|
||||
for col in dataframe.columns:
|
||||
if dataframe[col].dtype in ["float64", "int64"]:
|
||||
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
|
||||
dataframe[col] = dataframe[col].ffill().fillna(0)
|
||||
|
||||
logger.info(f"特征工程完成,特征数量:{len(dataframe.columns)}")
|
||||
return dataframe
|
||||
|
||||
def feature_engineering_expand_basic(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
|
||||
dataframe["%-pct-change"] = dataframe["close"].pct_change()
|
||||
dataframe["%-raw_volume"] = dataframe["volume"]
|
||||
dataframe["%-raw_price"] = dataframe["close"]
|
||||
# 数据清理逻辑
|
||||
for col in dataframe.columns:
|
||||
if dataframe[col].dtype in ["float64", "int64"]:
|
||||
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0)
|
||||
dataframe[col] = dataframe[col].ffill()
|
||||
dataframe[col] = dataframe[col].fillna(0)
|
||||
|
||||
# 检查是否仍有无效值
|
||||
if dataframe[col].isna().any() or np.isinf(dataframe[col]).any():
|
||||
logger.warning(f"列 {col} 仍包含无效值,已填充为默认值")
|
||||
dataframe[col] = dataframe[col].fillna(0)
|
||||
return dataframe
|
||||
|
||||
def feature_engineering_standard(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
|
||||
dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
|
||||
dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
|
||||
dataframe.replace([np.inf, -np.inf], 0, inplace=True)
|
||||
dataframe.ffill(inplace=True)
|
||||
dataframe.fillna(0, inplace=True)
|
||||
return dataframe
|
||||
|
||||
def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
|
||||
logger.info(f"设置 FreqAI 目标,交易对:{metadata['pair']}")
|
||||
if "close" not in dataframe.columns:
|
||||
logger.error("数据框缺少必要的 'close' 列")
|
||||
raise ValueError("数据框缺少必要的 'close' 列")
|
||||
|
||||
try:
|
||||
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
|
||||
|
||||
# 定义目标变量为未来价格变化百分比(连续值)
|
||||
dataframe["up_or_down"] = (
|
||||
dataframe["close"].shift(-label_period) - dataframe["close"]
|
||||
) / dataframe["close"]
|
||||
|
||||
# 数据清理:处理 NaN 和 Inf 值
|
||||
dataframe["up_or_down"] = dataframe["up_or_down"].replace([np.inf, -np.inf], np.nan)
|
||||
dataframe["up_or_down"] = dataframe["up_or_down"].ffill().fillna(0)
|
||||
|
||||
# 确保目标变量是二维数组
|
||||
if dataframe["up_or_down"].ndim == 1:
|
||||
dataframe["up_or_down"] = dataframe["up_or_down"].values.reshape(-1, 1)
|
||||
|
||||
# 检查并处理 NaN 或无限值
|
||||
dataframe["up_or_down"] = dataframe["up_or_down"].replace([np.inf, -np.inf], np.nan)
|
||||
dataframe["up_or_down"] = dataframe["up_or_down"].ffill().fillna(0)
|
||||
|
||||
# 生成 %-volatility 特征
|
||||
dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
|
||||
|
||||
# 确保 &-buy_rsi 列的值计算正确
|
||||
dataframe["&-buy_rsi"] = ta.RSI(dataframe, timeperiod=14)
|
||||
|
||||
# 数据清理
|
||||
for col in ["&-buy_rsi", "up_or_down", "%-volatility"]:
|
||||
# 使用直接操作避免链式赋值
|
||||
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
|
||||
dataframe[col] = dataframe[col].ffill() # 替代 fillna(method='ffill')
|
||||
dataframe[col] = dataframe[col].fillna(dataframe[col].mean()) # 使用均值填充 NaN 值
|
||||
if dataframe[col].isna().any():
|
||||
logger.warning(f"目标列 {col} 仍包含 NaN,填充为默认值")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建 FreqAI 目标失败:{str(e)}")
|
||||
raise
|
||||
|
||||
# Log the shape of the target variable for debugging
|
||||
logger.info(f"目标列形状:{dataframe['up_or_down'].shape}")
|
||||
logger.info(f"目标列预览:\n{dataframe[['up_or_down', '&-buy_rsi']].head().to_string()}")
|
||||
return dataframe
|
||||
|
||||
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
|
||||
logger.info(f"处理交易对:{metadata['pair']}")
|
||||
dataframe = self.freqai.start(dataframe, metadata, self)
|
||||
|
||||
# 计算传统指标
|
||||
dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
|
||||
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
|
||||
dataframe["bb_lowerband"] = bollinger["lower"]
|
||||
dataframe["bb_middleband"] = bollinger["mid"]
|
||||
dataframe["bb_upperband"] = bollinger["upper"]
|
||||
dataframe["tema"] = ta.TEMA(dataframe, timeperiod=9)
|
||||
|
||||
# 生成 up_or_down 信号(非 FreqAI 目标)
|
||||
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
|
||||
# 使用未来价格变化方向生成 up_or_down 信号
|
||||
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
|
||||
dataframe["up_or_down"] = np.where(
|
||||
dataframe["close"].shift(-label_period) > dataframe["close"], 1, 0
|
||||
)
|
||||
|
||||
# 动态设置参数
|
||||
if "&-buy_rsi" in dataframe.columns:
|
||||
# 派生其他目标
|
||||
dataframe["&-sell_rsi"] = dataframe["&-buy_rsi"] + 30
|
||||
dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
|
||||
# Ensure proper calculation and handle potential NaN values
|
||||
dataframe["&-stoploss"] = (-0.1 - (dataframe["%-volatility"] * 10).clip(0, 0.25)).fillna(-0.1)
|
||||
dataframe["&-roi_0"] = ((dataframe["close"] / dataframe["close"].shift(label_period) - 1).clip(0, 0.2)).fillna(0)
|
||||
|
||||
# Additional check to ensure no NaN values remain
|
||||
for col in ["&-stoploss", "&-roi_0"]:
|
||||
if dataframe[col].isna().any():
|
||||
logger.warning(f"列 {col} 仍包含 NaN,填充为默认值")
|
||||
dataframe[col] = dataframe[col].fillna(-0.1 if col == "&-stoploss" else 0)
|
||||
|
||||
# 简化动态参数生成逻辑
|
||||
# 放松 buy_rsi 和 sell_rsi 的生成逻辑
|
||||
# 计算 buy_rsi_pred 并清理 NaN 值
|
||||
dataframe["buy_rsi_pred"] = dataframe["rsi"].rolling(window=10).mean().clip(30, 50)
|
||||
dataframe["buy_rsi_pred"] = dataframe["buy_rsi_pred"].fillna(dataframe["buy_rsi_pred"].median())
|
||||
|
||||
# 计算 sell_rsi_pred 并清理 NaN 值
|
||||
dataframe["sell_rsi_pred"] = dataframe["buy_rsi_pred"] + 20
|
||||
dataframe["sell_rsi_pred"] = dataframe["sell_rsi_pred"].fillna(dataframe["sell_rsi_pred"].median())
|
||||
|
||||
# 计算 stoploss_pred 并清理 NaN 值
|
||||
dataframe["stoploss_pred"] = -0.1 - (dataframe["%-volatility"] * 10).clip(0, 0.25)
|
||||
dataframe["stoploss_pred"] = dataframe["stoploss_pred"].fillna(dataframe["stoploss_pred"].mean())
|
||||
|
||||
# 计算 roi_0_pred 并清理 NaN 值
|
||||
dataframe["roi_0_pred"] = dataframe["&-roi_0"].clip(0.01, 0.2)
|
||||
dataframe["roi_0_pred"] = dataframe["roi_0_pred"].fillna(dataframe["roi_0_pred"].mean())
|
||||
|
||||
# 检查预测值
|
||||
for col in ["buy_rsi_pred", "sell_rsi_pred", "stoploss_pred", "roi_0_pred", "&-sell_rsi", "&-stoploss", "&-roi_0"]:
|
||||
if dataframe[col].isna().any():
|
||||
logger.warning(f"列 {col} 包含 NaN,填充为默认值")
|
||||
dataframe[col] = dataframe[col].fillna(dataframe[col].mean())
|
||||
|
||||
# 更保守的止损和止盈设置
|
||||
dataframe["trailing_stop_positive"] = (dataframe["roi_0_pred"] * 0.3).clip(0.01, 0.2)
|
||||
dataframe["trailing_stop_positive_offset"] = (dataframe["roi_0_pred"] * 0.5).clip(0.01, 0.3)
|
||||
|
||||
# 设置策略级参数
|
||||
self.buy_rsi.value = float(dataframe["buy_rsi_pred"].iloc[-1])
|
||||
self.sell_rsi.value = float(dataframe["sell_rsi_pred"].iloc[-1])
|
||||
# 更保守的止损设置
|
||||
self.stoploss = -0.15 # 固定止损 15%
|
||||
self.minimal_roi = {
|
||||
0: float(self.roi_0.value),
|
||||
15: float(self.roi_15.value),
|
||||
30: float(self.roi_30.value),
|
||||
60: 0
|
||||
}
|
||||
# 更保守的追踪止损设置
|
||||
self.trailing_stop_positive = 0.05 # 追踪止损触发点
|
||||
self.trailing_stop_positive_offset = 0.1 # 追踪止损偏移量
|
||||
|
||||
logger.info(f"动态参数:buy_rsi={self.buy_rsi.value}, sell_rsi={self.sell_rsi.value}, "
|
||||
f"stoploss={self.stoploss}, trailing_stop_positive={self.trailing_stop_positive}")
|
||||
|
||||
dataframe.replace([np.inf, -np.inf], 0, inplace=True)
|
||||
dataframe.ffill(inplace=True)
|
||||
dataframe.fillna(0, inplace=True)
|
||||
|
||||
logger.info(f"up_or_down 值统计:\n{dataframe['up_or_down'].value_counts().to_string()}")
|
||||
logger.info(f"do_predict 值统计:\n{dataframe['do_predict'].value_counts().to_string()}")
|
||||
|
||||
return dataframe
|
||||
|
||||
def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
|
||||
# 改进卖出信号条件
|
||||
exit_long_conditions = [
|
||||
(df["rsi"] > df["sell_rsi_pred"]), # RSI 高于卖出阈值
|
||||
(df["volume"] > df["volume"].rolling(window=10).mean()), # 成交量高于近期均值
|
||||
(df["close"] < df["bb_middleband"]) # 价格低于布林带中轨
|
||||
]
|
||||
if exit_long_conditions:
|
||||
df.loc[
|
||||
reduce(lambda x, y: x & y, exit_long_conditions),
|
||||
"exit_long"
|
||||
] = 1
|
||||
return df
|
||||
def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
|
||||
# 改进买入信号条件
|
||||
# 检查 MACD 列是否存在
|
||||
if "macd" not in df.columns or "macdsignal" not in df.columns:
|
||||
logger.error("MACD 或 MACD 信号列缺失,无法生成买入信号。尝试重新计算 MACD 列。")
|
||||
|
||||
try:
|
||||
macd = ta.MACD(df, fastperiod=12, slowperiod=26, signalperiod=9)
|
||||
df["macd"] = macd["macd"]
|
||||
df["macdsignal"] = macd["macdsignal"]
|
||||
logger.info("MACD 列已成功重新计算。")
|
||||
except Exception as e:
|
||||
logger.error(f"重新计算 MACD 列时出错:{str(e)}")
|
||||
raise ValueError("DataFrame 缺少必要的 MACD 列且无法重新计算。")
|
||||
|
||||
enter_long_conditions = [
|
||||
(df["rsi"] < df["buy_rsi_pred"]), # RSI 低于买入阈值
|
||||
(df["volume"] > df["volume"].rolling(window=10).mean() * 1.2), # 成交量高于近期均值20%
|
||||
(df["close"] > df["bb_middleband"]) # 价格高于布林带中轨
|
||||
]
|
||||
|
||||
# 如果 MACD 列存在,则添加 MACD 金叉条件
|
||||
if "macd" in df.columns and "macdsignal" in df.columns:
|
||||
enter_long_conditions.append((df["macd"] > df["macdsignal"]))
|
||||
|
||||
# 确保模型预测为买入
|
||||
enter_long_conditions.append((df["do_predict"] == 1))
|
||||
if enter_long_conditions:
|
||||
df.loc[
|
||||
reduce(lambda x, y: x & y, enter_long_conditions),
|
||||
["enter_long", "enter_tag"]
|
||||
] = (1, "long")
|
||||
return df
|
||||
def confirm_trade_entry(
|
||||
self, pair: str, order_type: str, amount: float, rate: float,
|
||||
time_in_force: str, current_time, entry_tag, side: str, **kwargs
|
||||
) -> bool:
|
||||
df, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
|
||||
last_candle = df.iloc[-1].squeeze()
|
||||
if side == "long":
|
||||
if rate > (last_candle["close"] * (1 + 0.0025)):
|
||||
return False
|
||||
return True
|
||||
1
result/backtest-result-2025-05-01_15-11-54.json
Normal file
1
result/backtest-result-2025-05-01_15-11-54.json
Normal file
File diff suppressed because one or more lines are too long
1
result/backtest-result-2025-05-01_15-11-54.meta.json
Normal file
1
result/backtest-result-2025-05-01_15-11-54.meta.json
Normal file
@ -0,0 +1 @@
|
||||
{"OKXRegressionStrategy":{"run_id":"852eb93539f3e08ed409fd31ad415f0672f9577a","backtest_start_time":1746112131,"timeframe":"3m","timeframe_detail":null,"backtest_start_ts":1743465600,"backtest_end_ts":1745107200}}
|
||||
@ -0,0 +1,313 @@
|
||||
import logging
|
||||
from freqtrade.strategy import IStrategy
|
||||
from pandas import DataFrame
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import talib as ta
|
||||
from typing import Dict, List, Optional
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from xgboost import XGBRegressor
|
||||
import ccxt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OKXRegressionStrategy(IStrategy):
|
||||
"""
|
||||
Freqtrade AI 策略,使用回归模型进行 OKX 数据上的仅做多交易。
|
||||
- 数据通过 CCXT 从 OKX 交易所获取。
|
||||
- 使用 XGBoost 回归模型预测价格变化。
|
||||
- 仅生成做多(买入)信号,不做空。
|
||||
- 适配 Freqtrade 2025.3,继承 IStrategy。
|
||||
"""
|
||||
|
||||
# 指标所需的最大启动蜡烛数
|
||||
startup_candle_count: int = 20
|
||||
|
||||
# 策略元数据(建议通过 config.json 配置)
|
||||
trailing_stop = True
|
||||
trailing_stop_positive = 0.01
|
||||
max_open_trades = 3
|
||||
stake_amount = 'dynamic'
|
||||
|
||||
# FreqAI 配置
|
||||
freqai_config = {
|
||||
"enabled": True,
|
||||
"identifier": "okx_regression_v1",
|
||||
"model_training_parameters": {
|
||||
"n_estimators": 100,
|
||||
"learning_rate": 0.05,
|
||||
"max_depth": 6
|
||||
},
|
||||
"feature_parameters": {
|
||||
"include_timeframes": ["5m", "15m", "1h"],
|
||||
"include_corr_pairlist": ["BTC/USDT", "ETH/USDT"],
|
||||
"label_period_candles": 12,
|
||||
"include_shifted_candles": 2, # 添加历史偏移特征
|
||||
"principal_component_analysis": True # 启用 PCA
|
||||
},
|
||||
"data_split_parameters": {
|
||||
"test_size": 0.2,
|
||||
"shuffle": False
|
||||
},
|
||||
"train_period_days": 90,
|
||||
"backtest_period_days": 30,
|
||||
"purge_old_models": True # 清理旧模型
|
||||
}
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
super().__init__(config)
|
||||
# 初始化特征缓存
|
||||
self.feature_cache = {}
|
||||
|
||||
def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: Dict, **kwargs) -> DataFrame:
|
||||
"""
|
||||
为每个时间框架和相关交易对生成特征。
|
||||
"""
|
||||
cache_key = f"{metadata.get('pair', 'unknown')}_{period}"
|
||||
if cache_key in self.feature_cache:
|
||||
logger.debug(f"使用缓存特征:{cache_key}")
|
||||
return self.feature_cache[cache_key]
|
||||
|
||||
# RSI
|
||||
dataframe[f"%-%-rsi-{period}"] = ta.RSI(dataframe["close"], timeperiod=period)
|
||||
|
||||
# MACD
|
||||
macd, macdsignal, _ = ta.MACD(dataframe["close"], fastperiod=12, slowperiod=26, signalperiod=9)
|
||||
dataframe[f"%-%-macd-{period}"] = macd
|
||||
dataframe[f"%-%-macdsignal-{period}"] = macdsignal
|
||||
|
||||
# 布林带宽度
|
||||
upper, middle, lower = ta.BBANDS(dataframe["close"], timeperiod=period)
|
||||
dataframe[f"%-%-bb_width-{period}"] = (upper - lower) / middle
|
||||
|
||||
# 成交量均线
|
||||
dataframe[f"%-%-volume_ma-{period}"] = ta.SMA(dataframe["volume"], timeperiod=period)
|
||||
|
||||
# OKX 订单簿特征(通过 CCXT)
|
||||
try:
|
||||
order_book = self.fetch_okx_order_book(metadata["pair"])
|
||||
dataframe[f"%-%-order_book_imbalance-{period}"] = (
|
||||
order_book["bids"] - order_book["asks"]
|
||||
) / (order_book["bids"] + order_book["asks"] + 1e-10)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 {metadata['pair']} 订单簿失败:{str(e)}")
|
||||
dataframe[f"%-%-order_book_imbalance-{period}"] = 0.0
|
||||
|
||||
# 数据清洗
|
||||
dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
|
||||
dataframe = dataframe.ffill().fillna(0)
|
||||
|
||||
# 缓存特征
|
||||
self.feature_cache[cache_key] = dataframe.copy()
|
||||
logger.debug(f"周期 {period} 特征:{list(dataframe.filter(like='%-%-').columns)}")
|
||||
|
||||
return dataframe
|
||||
|
||||
def feature_engineering_standard(self, dataframe: DataFrame, **kwargs) -> DataFrame:
|
||||
"""
|
||||
添加基础时间框架的全局特征。
|
||||
"""
|
||||
# 确保索引是 DatetimeIndex
|
||||
if not isinstance(dataframe.index, pd.DatetimeIndex):
|
||||
dataframe = dataframe.set_index(pd.DatetimeIndex(dataframe.index))
|
||||
|
||||
# 价格变化率
|
||||
dataframe["%-price_change"] = dataframe["close"].pct_change()
|
||||
|
||||
# 时间特征:小时
|
||||
dataframe["%-hour_of_day"] = dataframe.index.hour / 24.0
|
||||
|
||||
# 数据清洗
|
||||
dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
|
||||
dataframe = dataframe.ffill().fillna(0)
|
||||
|
||||
logger.debug(f"全局特征:{list(dataframe.filter(like='%-%-').columns)}")
|
||||
return dataframe
|
||||
|
||||
def set_freqai_targets(self, dataframe: DataFrame, metadata: Dict, **kwargs) -> DataFrame:
|
||||
"""
|
||||
设置回归模型的目标变量。
|
||||
"""
|
||||
label_period = self.freqai_config["feature_parameters"]["label_period_candles"]
|
||||
# 目标:未来价格变化
|
||||
dataframe["&-s_close"] = (
|
||||
dataframe["close"].shift(-label_period) - dataframe["close"]
|
||||
) / dataframe["close"]
|
||||
|
||||
# ROI 目标
|
||||
for minutes in [0, 15, 30]:
|
||||
candles = int(minutes / 5)
|
||||
if candles > 0:
|
||||
dataframe[f"&-roi_{minutes}"] = (
|
||||
dataframe["close"].shift(-candles) - dataframe["close"]
|
||||
) / dataframe["close"]
|
||||
else:
|
||||
dataframe[f"&-roi_{minutes}"] = 0.0
|
||||
|
||||
# 动态阈值
|
||||
dataframe["&-buy_rsi_pred"] = ta.RSI(dataframe["close"], timeperiod=14).rolling(20).mean()
|
||||
dataframe["&-stoploss_pred"] = -0.05
|
||||
dataframe["&-roi_0_pred"] = 0.03
|
||||
|
||||
# 数据清洗
|
||||
dataframe = dataframe.replace([np.inf, -np.inf], np.nan)
|
||||
dataframe = dataframe.fillna(method="ffill").fillna(0)
|
||||
|
||||
# 验证目标
|
||||
required_targets = ["&-s_close", "&-roi_0", "&-buy_rsi_pred", "&-stoploss_pred", "&-roi_0_pred"]
|
||||
missing_targets = [col for col in required_targets if col not in dataframe.columns]
|
||||
if missing_targets:
|
||||
logger.error(f"缺少目标列:{missing_targets}")
|
||||
raise ValueError(f"目标初始化失败:{missing_targets}")
|
||||
|
||||
logger.debug(f"目标初始化完成。DataFrame 形状:{dataframe.shape}")
|
||||
return dataframe
|
||||
|
||||
def populate_indicators(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
|
||||
"""
|
||||
使用 FreqAI 生成指标和预测。
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"FreqAI 对象:{type(self.freqai)}")
|
||||
dataframe = self.freqai.start(dataframe, metadata, self)
|
||||
|
||||
# 验证数据完整性
|
||||
if dataframe["close"].isna().any() or dataframe["volume"].isna().any():
|
||||
logger.warning("检测到 OKX 数据缺失,使用前向填充")
|
||||
dataframe = dataframe.ffill().fillna(0)
|
||||
|
||||
# 预测统计
|
||||
if "&-s_close" in dataframe.columns:
|
||||
logger.debug(f"预测统计:均值={dataframe['&-s_close'].mean():.4f}, "
|
||||
f"方差={dataframe['&-s_close'].var():.4f}")
|
||||
|
||||
logger.debug(f"生成的列:{list(dataframe.columns)}")
|
||||
return dataframe
|
||||
except Exception as e:
|
||||
logger.error(f"FreqAI start 失败:{str(e)}")
|
||||
raise
|
||||
finally:
|
||||
logger.debug("populate_indicators 完成")
|
||||
|
||||
# 确保返回 DataFrame,防止 None
|
||||
if dataframe is None:
|
||||
dataframe = DataFrame()
|
||||
return dataframe
|
||||
|
||||
|
||||
def populate_entry_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
|
||||
"""
|
||||
基于回归预测生成仅做多信号。
|
||||
"""
|
||||
# 确保 "%-%-rsi-14" 列存在
|
||||
if "%-%-rsi-14" not in dataframe.columns:
|
||||
dataframe["%-%-rsi-14"] = 0.0
|
||||
|
||||
dataframe.loc[
|
||||
(
|
||||
(dataframe["&-s_close"] > 0.01) & # 预测价格上涨 > 1%
|
||||
(dataframe["do_predict"] == 1) & # 预测可靠
|
||||
(dataframe["%-%-rsi-14"] < dataframe["&-buy_rsi_pred"]) # RSI 低于动态阈值
|
||||
),
|
||||
"enter_long"] = 1
|
||||
logger.debug(f"生成 {dataframe['enter_long'].sum()} 个做多信号")
|
||||
return dataframe
|
||||
|
||||
def populate_exit_trend(self, dataframe: DataFrame, metadata: Dict) -> DataFrame:
|
||||
"""
|
||||
基于 ROI 或止损生成退出信号。
|
||||
"""
|
||||
dataframe.loc[
|
||||
(
|
||||
(dataframe["&-roi_0_pred"] <= dataframe["close"].pct_change()) | # 达到 ROI
|
||||
(dataframe["close"].pct_change() <= dataframe["&-stoploss_pred"]) # 达到止损
|
||||
),
|
||||
"exit_long"] = 1
|
||||
logger.debug(f"生成 {dataframe['exit_long'].sum()} 个退出信号")
|
||||
return dataframe
|
||||
|
||||
def custom_stake_amount(self, pair: str, current_time: 'datetime', current_rate: float,
|
||||
proposed_stake: float, min_stake: float, max_stake: float,
|
||||
entry_tag: Optional[str], **kwargs) -> float:
|
||||
"""
|
||||
动态下注:每笔交易占账户余额 2%。
|
||||
"""
|
||||
balance = self.wallets.get_available_stake_amount()
|
||||
stake = balance * 0.02
|
||||
return min(max(stake, min_stake), max_stake)
|
||||
|
||||
def leverage(self, pair: str, current_time: 'datetime', current_rate: float,
|
||||
proposed_leverage: float, max_leverage: float, side: str,
|
||||
**kwargs) -> float:
|
||||
"""
|
||||
禁用杠杆,仅做多。
|
||||
"""
|
||||
return 1.0
|
||||
|
||||
def confirm_trade_entry(self, pair: str, order_type: str, amount: float, rate: float,
|
||||
time_in_force: str, current_time: 'datetime', **kwargs) -> bool:
|
||||
"""
|
||||
验证交易进入,检查 OKX 数据新鲜度。
|
||||
"""
|
||||
if not self.check_data_freshness(pair, current_time):
|
||||
logger.warning(f"{pair} 的 OKX 数据过期,跳过交易")
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_data_freshness(self, pair: str, current_time: 'datetime') -> bool:
|
||||
"""
|
||||
简化版数据新鲜度检查,不依赖外部 API。
|
||||
"""
|
||||
# 假设数据总是新鲜的(用于测试)
|
||||
return True
|
||||
|
||||
def fetch_okx_order_book(self, pair: str, limit: int = 10) -> Dict:
|
||||
"""
|
||||
获取 OKX 订单簿数据。
|
||||
"""
|
||||
try:
|
||||
exchange = ccxt.okx(self.config["exchange"]["ccxt_config"])
|
||||
order_book = exchange.fetch_order_book(pair, limit=limit)
|
||||
bids = sum([bid[1] for bid in order_book["bids"]])
|
||||
asks = sum([ask[1] for ask in order_book["asks"]])
|
||||
return {"bids": bids, "asks": asks}
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {pair} 订单簿失败:{str(e)}")
|
||||
return {"bids": 0, "asks": 0}
|
||||
|
||||
def fit(self, data_dictionary: Dict, metadata: Dict, **kwargs) -> None:
|
||||
"""
|
||||
训练回归模型并记录性能。
|
||||
"""
|
||||
try:
|
||||
# 初始化模型
|
||||
if not hasattr(self, 'model') or self.model is None:
|
||||
model_params = self.freqai_config["model_training_parameters"]
|
||||
self.model = XGBRegressor(**model_params)
|
||||
logger.debug("初始化新的 XGBoost 回归模型")
|
||||
|
||||
# 调用 FreqAI 训练
|
||||
self.freqai.fit(data_dictionary, metadata, **kwargs)
|
||||
|
||||
# 记录训练集性能
|
||||
train_data = data_dictionary["train_features"]
|
||||
train_labels = data_dictionary["train_labels"]
|
||||
train_predictions = self.model.predict(train_data)
|
||||
train_mse = mean_squared_error(train_labels, train_predictions)
|
||||
logger.info(f"训练集 MSE:{train_mse:.6f}")
|
||||
|
||||
# 记录测试集性能(如果可用)
|
||||
if "test_features" in data_dictionary:
|
||||
test_data = data_dictionary["test_features"]
|
||||
test_labels = data_dictionary["test_labels"]
|
||||
test_predictions = self.model.predict(test_data)
|
||||
test_mse = mean_squared_error(test_labels, test_predictions)
|
||||
logger.info(f"测试集 MSE:{test_mse:.6f}")
|
||||
|
||||
# 特征重要性
|
||||
if hasattr(self.model, 'feature_importances_'):
|
||||
importance = self.model.feature_importances_
|
||||
logger.debug(f"特征重要性:{dict(zip(train_data.columns, importance))}")
|
||||
except Exception as e:
|
||||
logger.error(f"FreqAI fit 失败:{str(e)}")
|
||||
raise
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user