# myTestFreqAI/freqtrade/templates/FreqaiExampleStrategy.py
# zhangkun9038@dingtalk.com a46895526c stable1 11
# 2025-04-28 20:51:06 +08:00
#
# 231 lines
# 12 KiB
# Python

import logging
import numpy as np
from functools import reduce
import talib.abstract as ta
from pandas import DataFrame
from technical import qtpylib
from freqtrade.strategy import IStrategy, IntParameter, DecimalParameter
# Module-level logger used by the strategy below for progress and error messages.
logger = logging.getLogger(__name__)
class FreqaiExampleStrategy(IStrategy):
    """Long-only FreqAI strategy driven by a regressed forward-return target.

    ``set_freqai_targets`` defines two targets: ``&-up_or_down`` (relative
    change of ``close`` over ``label_period_candles`` candles) and
    ``&-buy_rsi`` (14-period RSI). ``populate_indicators`` runs FreqAI through
    ``self.freqai.start`` and then adds classic indicators. Entries require
    ``&-up_or_down`` > 0.003 together with RSI, volume, Bollinger and MACD
    filters. Exits fire when any one of the mirrored exit conditions holds.
    Stoploss, trailing-stop distances and ``minimal_roi`` are reassigned on
    every ``populate_indicators`` call.
    """

    minimal_roi = {}  # Reassigned from the roi_* parameters in populate_indicators.
    # Placeholder until populate_indicators derives the ATR-based stop. A stop
    # of 0.0 tolerates no loss at all, so start from the stoploss_param default.
    stoploss = -0.1
    trailing_stop = True
    process_only_new_candles = True
    use_exit_signal = True
    startup_candle_count: int = 40
    can_short = False

    # Hyperoptable parameter definitions.
    buy_rsi = IntParameter(low=10, high=50, default=30, space="buy", optimize=True, load=True)
    sell_rsi = IntParameter(low=50, high=90, default=70, space="sell", optimize=True, load=True)
    roi_0 = DecimalParameter(low=0.01, high=0.1, default=0.05, space="roi", optimize=True, load=True)
    roi_15 = DecimalParameter(low=0.005, high=0.05, default=0.03, space="roi", optimize=True, load=True)
    roi_30 = DecimalParameter(low=0.001, high=0.03, default=0.01, space="roi", optimize=True, load=True)
    stoploss_param = DecimalParameter(
        low=-0.2, high=-0.05, default=-0.1, space="stoploss", optimize=True, load=True
    )

    # FreqAI configuration.
    # NOTE(review): freqtrade normally loads FreqAI settings from the "freqai"
    # section of the config; confirm whether this class-level dict is what
    # set_freqai_targets actually reads through self.freqai_info at runtime.
    freqai_info = {
        "model": "XGBoostRegressor",
        "return_type": "raw",  # use the raw regression output
        "feature_parameters": {
            "include_timeframes": ["15m", "1h", "4h"],
            "include_corr_pairlist": ["BTC/USDT", "SOL/USDT"],
            "label_period_candles": 20,  # longer prediction horizon
            "include_shifted_candles": 2,
            "weight_factor": 0.9,
            "principal_component_analysis": False,
            "use_SVM_to_remove_outliers": True,
            "SVM_parameters": {"nu": 0.1},
        },
        "data_split_parameters": {
            "test_size": 0.2,
            # NOTE(review): shuffling time-series rows interleaves the test split
            # (and early-stopping data) with training rows; consider False.
            "shuffle": True,
        },
        "model_training_parameters": {
            "n_estimators": 300,  # more boosting rounds
            "learning_rate": 0.03,  # lower learning rate
            "max_depth": 6,
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "reg_lambda": 1.0,
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "early_stopping_rounds": 20,
            "verbose": 0,
        },
        "data_kitchen": {
            "feature_parameters": {
                "DI_threshold": 0  # disable DI (dissimilarity index) filtering
            }
        },
    }

    @staticmethod
    def _clean_numeric_columns(dataframe: DataFrame) -> DataFrame:
        """Sanitise every float64/int64 column in place and return the dataframe.

        +/-inf becomes NaN and values are forward-filled. Any NaNs still left
        (only leading ones after ffill) take the column median. The median is
        computed after inf removal, so inf cannot become the fill value. It is
        a whole-column statistic, so it includes later rows.
        """
        for col in dataframe.columns:
            if dataframe[col].dtype in ("float64", "int64"):
                cleaned = dataframe[col].replace([np.inf, -np.inf], np.nan)
                dataframe[col] = cleaned.ffill().fillna(cleaned.median())
        return dataframe

    def calculate_macd(self, dataframe: DataFrame) -> DataFrame:
        """Add ``macd`` and ``macdsignal`` columns (MACD 12/26/9) and return the dataframe.

        Infinite values become NaN and are forward-filled; leading gaps are set
        to 0. If TA-Lib raises, both columns are set to NaN and the error is
        logged instead of propagated.
        """
        try:
            macd = ta.MACD(dataframe, fastperiod=12, slowperiod=26, signalperiod=9)
            dataframe["macd"] = macd["macd"].replace([np.inf, -np.inf], np.nan).ffill().fillna(0)
            dataframe["macdsignal"] = macd["macdsignal"].replace([np.inf, -np.inf], np.nan).ffill().fillna(0)
            logger.info("MACD calculated successfully.")
        except Exception as e:
            logger.error(f"Error calculating MACD: {str(e)}")
            dataframe["macd"] = np.nan
            dataframe["macdsignal"] = np.nan
        return dataframe

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
        """Add trend, volatility and correlated-pair features for FreqAI.

        Every feature name carries the ``%-`` prefix, which is how FreqAI marks
        model features. Without it these columns were computed and then
        never used as model inputs.

        :param dataframe: candle data for ``metadata["pair"]``; it must contain
            a ``date`` column.
        :param period: indicator period supplied by FreqAI. NOTE(review): it is
            unused; the fixed windows from the original design are kept.
        :param metadata: dict with at least ``"pair"``. ``"tf"`` is used when
            present (presumably the frame's timeframe; confirm).
        :return: the same dataframe with features added and numeric columns cleaned.
        """
        # Trend features.
        dataframe["%-rsi"] = ta.RSI(dataframe, timeperiod=14)
        dataframe["%-sma20"] = ta.SMA(dataframe["close"], timeperiod=20)
        dataframe["%-ema50"] = ta.EMA(dataframe["close"], timeperiod=50)
        dataframe["%-atr"] = ta.ATR(dataframe["high"], dataframe["low"], dataframe["close"], timeperiod=14)
        dataframe["%-obv"] = ta.OBV(dataframe["close"], dataframe["volume"])
        dataframe["%-adx"] = ta.ADX(dataframe["high"], dataframe["low"], dataframe["close"], timeperiod=14)
        dataframe["%-price_sma_diff"] = (dataframe["close"] - dataframe["%-sma20"]) / dataframe["%-sma20"]
        dataframe["%-momentum"] = ta.MOM(dataframe["close"], timeperiod=14)
        bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
        dataframe["%-bb_lowerband"] = bollinger["lower"]
        dataframe["%-bb_middleband"] = bollinger["mid"]
        dataframe["%-bb_upperband"] = bollinger["upper"]
        dataframe["%-bb_width"] = (bollinger["upper"] - bollinger["lower"]) / bollinger["mid"]

        # Correlated-pair features. BTC rows are matched to this frame by candle
        # date, at this frame's timeframe. Assigning by row position mixed up
        # candles whenever the two frames differed in length, start date or
        # timeframe.
        if metadata["pair"] == "SOL/USDT":
            btc_data = self.dp.get_pair_dataframe(
                pair="BTC/USDT", timeframe=metadata.get("tf", self.timeframe)
            )
            if not btc_data.empty:
                btc = btc_data[["date"]].copy()
                btc["rsi"] = ta.RSI(btc_data, timeperiod=14)
                btc["pct_change"] = btc_data["close"].pct_change()
                btc = btc.set_index("date")
                dataframe["%-btc_rsi"] = dataframe["date"].map(btc["rsi"])
                dataframe["%-btc_price_change"] = dataframe["date"].map(btc["pct_change"])

        # Data cleaning.
        self._clean_numeric_columns(dataframe)
        logger.info(f"Feature engineering completed, features: {len(dataframe.columns)}")
        return dataframe

    def feature_engineering_expand_basic(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add percent-change, raw-volume and raw-price features, then clean numeric columns."""
        dataframe["%-pct-change"] = dataframe["close"].pct_change()
        dataframe["%-raw_volume"] = dataframe["volume"]
        dataframe["%-raw_price"] = dataframe["close"]
        self._clean_numeric_columns(dataframe)
        return dataframe

    def feature_engineering_standard(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add day-of-week and hour-of-day features from ``date``, then clean numeric columns."""
        dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
        dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
        self._clean_numeric_columns(dataframe)
        return dataframe

    def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Define the FreqAI targets ``&-up_or_down`` and ``&-buy_rsi``.

        ``&-up_or_down`` is ``(close[t + label_period] - close[t]) / close[t]``.
        The last ``label_period`` rows have no future close and are left NaN
        rather than filled. Filling them (as before) fabricated training
        labels; NaN label rows are presumably dropped by FreqAI's training
        filter (confirm for your version).

        :raises ValueError: if ``dataframe`` has no ``close`` column.
        :raises Exception: any error while building targets is logged and re-raised.
        """
        logger.info(f"Setting FreqAI targets for pair: {metadata['pair']}")
        if "close" not in dataframe.columns:
            logger.error("DataFrame missing 'close' column")
            raise ValueError("DataFrame missing 'close' column")
        try:
            label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
            future_close = dataframe["close"].shift(-label_period)
            forward_return = (future_close - dataframe["close"]) / dataframe["close"]
            # Only inf is neutralised; the trailing no-future rows stay NaN.
            dataframe["&-up_or_down"] = forward_return.replace([np.inf, -np.inf], np.nan)
            dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
            dataframe["&-buy_rsi"] = ta.RSI(dataframe, timeperiod=14)
            for col in ["&-buy_rsi", "%-volatility"]:
                cleaned = dataframe[col].replace([np.inf, -np.inf], np.nan)
                dataframe[col] = cleaned.ffill().fillna(cleaned.median())
        except Exception as e:
            logger.error(f"Failed to create FreqAI targets: {str(e)}")
            raise
        logger.info(f"Target column shape: {dataframe['&-up_or_down'].shape}")
        logger.info(f"Target preview:\n{dataframe[['&-up_or_down', '&-buy_rsi']].head().to_string()}")
        return dataframe

    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Run FreqAI, add classic indicators and derive dynamic risk settings.

        Side effects: reassigns ``buy_rsi.value``, ``sell_rsi.value``,
        ``stoploss``, ``minimal_roi``, ``trailing_stop_positive`` and
        ``trailing_stop_positive_offset`` from the last row of the dataframe.
        NOTE(review): if this runs once over a whole history (as in a
        backtest), every simulated trade uses the final candle's values, which
        is a lookahead. A per-candle ``custom_stoploss`` would avoid that.
        """
        logger.info(f"Processing pair: {metadata['pair']}")
        dataframe = self.freqai.start(dataframe, metadata, self)

        # Classic indicators used by the entry/exit rules.
        dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
        dataframe["atr"] = ta.ATR(dataframe["high"], dataframe["low"], dataframe["close"], timeperiod=14)
        bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
        dataframe["bb_lowerband"] = bollinger["lower"]
        dataframe["bb_middleband"] = bollinger["mid"]
        dataframe["bb_upperband"] = bollinger["upper"]
        dataframe = self.calculate_macd(dataframe)

        # Dynamic parameter columns.
        if "&-buy_rsi" in dataframe.columns:
            dataframe["&-sell_rsi"] = dataframe["&-buy_rsi"] + 20
        dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
        dataframe["&-stoploss"] = -2 * dataframe["atr"] / dataframe["close"]
        dataframe["&-roi_0"] = (dataframe["close"] / dataframe["close"].shift(10) - 1).clip(0, 0.1).fillna(0)

        # Dynamic predicted values.
        dataframe["buy_rsi_pred"] = dataframe["rsi"].rolling(window=10).mean().clip(20, 50).fillna(dataframe["rsi"].median())
        dataframe["sell_rsi_pred"] = dataframe["buy_rsi_pred"] + 20
        # Clip the ATR-based stop *ratio* into the stoploss_param range. The old
        # expression clipped the close price (operator precedence), which gave
        # a large positive stoploss.
        dataframe["stoploss_pred"] = (-2 * dataframe["atr"] / dataframe["close"]).clip(-0.2, -0.05)
        dataframe["roi_0_pred"] = dataframe["&-roi_0"].clip(0.01, 0.1).fillna(dataframe["&-roi_0"].mean())
        for col in ["&-stoploss", "&-roi_0", "buy_rsi_pred", "sell_rsi_pred", "stoploss_pred", "roi_0_pred"]:
            cleaned = dataframe[col].replace([np.inf, -np.inf], np.nan)
            dataframe[col] = cleaned.ffill().fillna(cleaned.mean())

        # Strategy parameters taken from the most recent row.
        self.buy_rsi.value = float(dataframe["buy_rsi_pred"].iloc[-1])
        self.sell_rsi.value = float(dataframe["sell_rsi_pred"].iloc[-1])
        predicted_stop = float(dataframe["stoploss_pred"].iloc[-1])
        # Fall back to the hyperopt default if no finite prediction exists.
        self.stoploss = predicted_stop if np.isfinite(predicted_stop) else float(self.stoploss_param.value)
        self.minimal_roi = {
            0: float(self.roi_0.value),
            15: float(self.roi_15.value),
            30: float(self.roi_30.value),
            60: 0,
        }

        # Dynamic trailing stop: 2x / 3x ATR relative to the last close.
        last_atr = dataframe["atr"].iloc[-1]
        last_close = dataframe["close"].iloc[-1]
        self.trailing_stop_positive = float(2 * last_atr / last_close)
        self.trailing_stop_positive_offset = float(3 * last_atr / last_close)

        dataframe.ffill(inplace=True)
        dataframe.fillna(dataframe.mean(numeric_only=True), inplace=True)
        logger.info(f"&-up_or_down stats:\n{dataframe['&-up_or_down'].describe().to_string()}")
        return dataframe

    def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Set ``enter_long`` = 1 and ``enter_tag`` = "long" where all entry filters agree.

        Filters: RSI below ``buy_rsi_pred``; volume more than 5% above its
        10-candle mean; close above the Bollinger middle band; MACD above its
        signal; ``&-up_or_down`` above 0.003.
        NOTE(review): FreqAI usually also outputs a ``do_predict`` reliability
        flag; it is not checked here.
        """
        # populate_indicators already adds these columns; recompute only if missing.
        if "macd" not in dataframe.columns or "macdsignal" not in dataframe.columns:
            dataframe = self.calculate_macd(dataframe)
        enter_long_conditions = [
            dataframe["rsi"] < dataframe["buy_rsi_pred"],
            dataframe["volume"] > dataframe["volume"].rolling(window=10).mean() * 1.05,
            dataframe["close"] > dataframe["bb_middleband"],
            dataframe["macd"] > dataframe["macdsignal"],
            dataframe["&-up_or_down"] > 0.003,  # lowered threshold
        ]
        entry_mask = reduce(lambda x, y: x & y, enter_long_conditions)
        dataframe.loc[entry_mask, ["enter_long", "enter_tag"]] = (1, "long")
        return dataframe

    def populate_exit_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Set ``exit_long`` = 1 where ANY exit condition holds.

        Conditions: RSI above ``sell_rsi_pred``; close below the Bollinger
        middle band; MACD below its signal; ``&-up_or_down`` below -0.003.
        """
        exit_long_conditions = [
            dataframe["rsi"] > dataframe["sell_rsi_pred"],
            dataframe["close"] < dataframe["bb_middleband"],
            dataframe["macd"] < dataframe["macdsignal"],
            dataframe["&-up_or_down"] < -0.003,
        ]
        exit_mask = reduce(lambda x, y: x | y, exit_long_conditions)
        dataframe.loc[exit_mask, "exit_long"] = 1
        return dataframe

    def confirm_trade_entry(
        self, pair: str, order_type: str, amount: float, rate: float,
        time_in_force: str, current_time, entry_tag, side: str, **kwargs
    ) -> bool:
        """Reject long entries priced more than 0.25% above the last analyzed close.

        Returns True (allow) when no analyzed candle is available. Previously
        an empty dataframe raised IndexError on ``iloc[-1]``.
        """
        df, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        if df is None or df.empty:
            return True
        last_close = df["close"].iloc[-1]
        if side == "long" and rate > last_close * (1 + 0.0025):
            return False
        return True