# myTestFreqAI/freqtrade/templates/FreqaiExampleStrategy.py
# zhangkun9038@dingtalk.com 9bb14377ed stable1 14
# 2025-04-28 21:07:51 +08:00
#
# 290 lines
# 15 KiB
# Python

import logging
import numpy as np
import pandas as pd
from functools import reduce
import talib.abstract as ta
from pandas import DataFrame
from technical import qtpylib
from freqtrade.strategy import IStrategy, IntParameter, DecimalParameter
# Module-level logger used by every method of the strategy below.
logger = logging.getLogger(__name__)
class FreqaiExampleStrategy(IStrategy):
    """
    Long-only FreqAI strategy driven by an XGBoost regressor.

    FreqAI is trained on two regression targets built in ``set_freqai_targets``:
    ``&-up_or_down`` (the standardized, sign-preserving log-compressed forward
    return over ``label_period_candles`` candles) and ``&-buy_rsi`` (14-period RSI).

    * Entry: at least 3 of 5 conditions (RSI below a rolling RSI level, volume
      above its 10-candle mean, close above the Bollinger middle band, MACD above
      its signal line, predicted ``&-up_or_down`` above a pair-specific threshold).
    * Exit: at least 3 of 4 conditions, on two consecutive candles.
    * Stoploss, ROI table and trailing-stop offsets are recomputed on every
      ``populate_indicators`` call (stoploss, ROI[0] and trailing offsets from ATR).

    Intended for short-term spot trading of BTC/USDT, ETH/USDT and SOL/USDT on
    the 15m timeframe.
    """

    # Strategy parameters
    # Empty here; populate_indicators() replaces it with a dynamic table.
    minimal_roi = {}
    # Initial stoploss until populate_indicators() sets the ATR-based value.
    # Starts at the same -10% as the `stoploss_param` default: a 0.0 stoploss is a
    # zero-distance stop (the stop would sit at the entry price).
    stoploss = -0.1
    trailing_stop = True
    process_only_new_candles = True
    use_exit_signal = True
    startup_candle_count: int = 40
    can_short = False
    timeframe = "15m"

    # Hyperopt parameters
    # NOTE(review): buy_rsi/sell_rsi values are overwritten in populate_indicators()
    # but the entry/exit signals read the buy_rsi_pred/sell_rsi_pred columns instead,
    # and stoploss_param is not read anywhere in this class.
    buy_rsi = IntParameter(low=10, high=50, default=30, space="buy", optimize=True, load=True)
    sell_rsi = IntParameter(low=50, high=90, default=70, space="sell", optimize=True, load=True)
    roi_0 = DecimalParameter(low=0.01, high=0.1, default=0.05, space="roi", optimize=True, load=True)
    roi_15 = DecimalParameter(low=0.005, high=0.05, default=0.03, space="roi", optimize=True, load=True)
    roi_30 = DecimalParameter(low=0.001, high=0.03, default=0.01, space="roi", optimize=True, load=True)
    stoploss_param = DecimalParameter(low=-0.2, high=-0.05, default=-0.1, space="stoploss", optimize=True, load=True)

    # FreqAI configuration
    # NOTE(review): freqtrade normally populates freqai_info from the "freqai"
    # section of the bot config; confirm whether this class-level dict is used.
    freqai_info = {
        "model": "XGBoostRegressor",
        "return_type": "raw",  # Use raw regression predictions
        "feature_parameters": {
            "include_timeframes": ["15m", "1h", "4h"],
            "include_corr_pairlist": ["BTC/USDT", "SOL/USDT", "ETH/USDT"],
            "label_period_candles": 60,  # Extended prediction horizon
            "include_shifted_candles": 2,
            "weight_factor": 0.9,
            "principal_component_analysis": False,
            "use_SVM_to_remove_outliers": True,
            "SVM_parameters": {"nu": 0.1},
            "DI_threshold": 0,  # Disable DI filtering
            "indicator_periods_candles": [14, 20],
        },
        "data_split_parameters": {
            "test_size": 0.2,
            # NOTE(review): shuffling a time-ordered split interleaves test rows with
            # training rows, which inflates test scores — confirm this is intended.
            "shuffle": True,
        },
        "model_training_parameters": {
            "n_estimators": 300,
            "learning_rate": 0.03,
            "max_depth": 6,
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "reg_lambda": 1.0,
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "early_stopping_rounds": 20,
            "verbose": 0,
        },
        "data_kitchen": {
            "feature_parameters": {
                "DI_threshold": 0  # Ensure DI filtering is disabled
            }
        },
    }

    @staticmethod
    def _clean_numeric_columns(dataframe: DataFrame) -> DataFrame:
        """
        Clean every float64/int64 column of ``dataframe`` in place and return it.

        Replaces +/-inf with NaN, forward-fills, then fills whatever is still missing
        (leading rows) with the column median. The median is taken after the inf
        replacement so infinities cannot skew it. Other dtypes are left untouched.
        """
        for col in dataframe.columns:
            if dataframe[col].dtype in ["float64", "int64"]:
                cleaned = dataframe[col].replace([np.inf, -np.inf], np.nan)
                dataframe[col] = cleaned.ffill().fillna(cleaned.median())
        return dataframe

    @staticmethod
    def _merge_by_date(
        dataframe: DataFrame, informative: DataFrame, columns: dict, offset: pd.Timedelta
    ) -> DataFrame:
        """
        Attach ``columns`` computed on ``informative`` to ``dataframe`` by date.

        Each informative row is keyed at ``informative["date"] + offset`` and every
        ``dataframe`` row takes the latest informative row whose key is <= its own
        ``date`` (``pd.merge_asof`` with ``direction="backward"``). Assuming ``date``
        is the candle open time, passing the informative timeframe as ``offset``
        keys rows by candle close, so no not-yet-closed informative candle leaks in.

        Args:
            dataframe: Target frame; needs a ``date`` column sorted ascending.
            informative: Source frame with a ``date`` column of the same dtype.
            columns: Mapping of new column name -> values indexed like ``informative``.
            offset: Shift applied to ``informative["date"]`` before matching.

        Returns:
            ``dataframe`` with the new columns added (row order/index unchanged).
        """
        right = pd.DataFrame({"date": informative["date"] + offset})
        for name, values in columns.items():
            right[name] = values
        right = right.dropna(subset=["date"]).sort_values("date")
        # merge_asof returns rows in the order of the (sorted) left keys, so the
        # result lines up positionally with `dataframe`.
        merged = pd.merge_asof(dataframe[["date"]], right, on="date", direction="backward")
        for name in columns:
            dataframe[name] = merged[name].to_numpy()
        return dataframe

    def calculate_macd(self, dataframe: DataFrame) -> DataFrame:
        """
        Add ``macd`` and ``macdsignal`` (12/26/9) columns to ``dataframe``.

        Infinite values are replaced, then forward-filled and finally filled with 0.
        On any error both columns are set to NaN and the error is logged.
        """
        try:
            macd = ta.MACD(dataframe, fastperiod=12, slowperiod=26, signalperiod=9)
            dataframe["macd"] = macd["macd"].replace([np.inf, -np.inf], np.nan).ffill().fillna(0)
            dataframe["macdsignal"] = macd["macdsignal"].replace([np.inf, -np.inf], np.nan).ffill().fillna(0)
            logger.info("MACD calculated successfully.")
        except Exception as e:
            logger.error(f"Error calculating MACD: {str(e)}")
            dataframe["macd"] = np.nan
            dataframe["macdsignal"] = np.nan
        return dataframe

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
        """
        Add technical, cross-timeframe and correlated-pair indicator columns.

        Cross-timeframe (1h/4h) and BTC/USDT columns are aligned to ``dataframe`` on
        ``date`` rather than assigned by row index: those frames have different
        lengths and timestamps, so index assignment would pair unrelated candles.
        """
        # NOTE(review): `period` is not used — every window below is fixed.
        # NOTE(review): FreqAI conventionally treats only `%`-prefixed columns as model
        # features (compare feature_engineering_expand_basic); these unprefixed
        # columns are likely ignored by the model — confirm before relying on them.

        # Standard technical indicators
        dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
        dataframe["sma20"] = ta.SMA(dataframe["close"], timeperiod=20)
        dataframe["ema50"] = ta.EMA(dataframe["close"], timeperiod=50)
        dataframe["atr"] = ta.ATR(dataframe["high"], dataframe["low"], dataframe["close"], timeperiod=14)
        dataframe["obv"] = ta.OBV(dataframe["close"], dataframe["volume"])
        dataframe["adx"] = ta.ADX(dataframe["high"], dataframe["low"], dataframe["close"], timeperiod=14)
        dataframe["momentum"] = ta.MOM(dataframe["close"], timeperiod=14)
        dataframe["price_sma_diff"] = (dataframe["close"] - dataframe["sma20"]) / dataframe["sma20"]

        # Bollinger Bands
        bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
        dataframe["bb_lowerband"] = bollinger["lower"]
        dataframe["bb_middleband"] = bollinger["mid"]
        dataframe["bb_upperband"] = bollinger["upper"]
        dataframe["bb_width"] = (bollinger["upper"] - bollinger["lower"]) / bollinger["mid"]

        # Cross-timeframe features, keyed by the informative candle's close time.
        for tf in ["1h", "4h"]:
            tf_data = self.dp.get_pair_dataframe(pair=metadata["pair"], timeframe=tf)
            if not tf_data.empty:
                bollinger_tf = qtpylib.bollinger_bands(qtpylib.typical_price(tf_data), window=20, stds=2)
                dataframe = self._merge_by_date(
                    dataframe,
                    tf_data,
                    {
                        f"rsi_{tf}": ta.RSI(tf_data, timeperiod=14),
                        f"bb_width_{tf}": (bollinger_tf["upper"] - bollinger_tf["lower"]) / bollinger_tf["mid"],
                    },
                    pd.Timedelta(tf),
                )

        # Correlated pair features (same timeframe, so match candles on equal dates).
        if metadata["pair"] == "SOL/USDT":
            btc_data = self.dp.get_pair_dataframe(pair="BTC/USDT", timeframe=self.timeframe)
            if not btc_data.empty:
                btc_change = btc_data["close"].pct_change()
                dataframe = self._merge_by_date(
                    dataframe,
                    btc_data,
                    {
                        "btc_rsi": ta.RSI(btc_data, timeperiod=14),
                        "btc_price_change": btc_change,
                        "btc_volatility": btc_change.rolling(20).std(),
                    },
                    pd.Timedelta(0),
                )

        # Data cleaning
        dataframe = self._clean_numeric_columns(dataframe)
        logger.info(f"Feature engineering completed, features: {len(dataframe.columns)}")
        return dataframe

    def feature_engineering_expand_basic(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add basic ``%-`` features (close pct-change, raw volume, raw close) and clean numeric columns."""
        dataframe["%-pct-change"] = dataframe["close"].pct_change()
        dataframe["%-raw_volume"] = dataframe["volume"]
        dataframe["%-raw_price"] = dataframe["close"]
        return self._clean_numeric_columns(dataframe)

    def feature_engineering_standard(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add calendar features (day of week, hour of day) and clean numeric columns."""
        dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
        dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
        return self._clean_numeric_columns(dataframe)

    def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """
        Create the FreqAI regression targets (``&-`` columns) and ``%-volatility``.

        ``&-up_or_down`` is the forward return over ``label_period_candles`` candles,
        standardized over the frame and then compressed with
        ``sign(x) * log1p(|x|)``. ``&-buy_rsi`` is the 14-period RSI.

        Rows whose label cannot be computed (the last ``label_period_candles`` rows,
        which have no future close, and the RSI warm-up rows) are left NaN instead
        of being forward-filled with the last known label, which would fabricate
        training targets.

        Raises:
            ValueError: if ``dataframe`` has no ``close`` column.
        """
        logger.info(f"Setting FreqAI targets for pair: {metadata['pair']}")
        if "close" not in dataframe.columns:
            logger.error("DataFrame missing 'close' column")
            raise ValueError("DataFrame missing 'close' column")
        try:
            label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
            forward_return = (
                dataframe["close"].shift(-label_period) - dataframe["close"]
            ) / dataframe["close"]
            # Standardize; when the spread is zero/undefined only center the values so
            # the division cannot turn every label into inf/NaN.
            centered = forward_return - forward_return.mean()
            spread = forward_return.std()
            target = centered / spread if np.isfinite(spread) and spread > 0 else centered
            # Sign-preserving log transform: compresses magnitudes, keeps direction.
            target = np.log1p(target.abs()) * np.sign(target)
            dataframe["&-up_or_down"] = target.replace([np.inf, -np.inf], np.nan)

            dataframe["&-buy_rsi"] = ta.RSI(dataframe, timeperiod=14)
            dataframe["&-buy_rsi"] = dataframe["&-buy_rsi"].replace([np.inf, -np.inf], np.nan)

            volatility = dataframe["close"].pct_change().rolling(20).std().replace([np.inf, -np.inf], np.nan)
            dataframe["%-volatility"] = volatility.ffill().fillna(volatility.median())
        except Exception as e:
            logger.error(f"Failed to create FreqAI targets: {str(e)}")
            raise
        logger.info(f"Target column shape: {dataframe['&-up_or_down'].shape}")
        logger.info(f"Target preview:\n{dataframe[['&-up_or_down', '&-buy_rsi']].head().to_string()}")
        return dataframe

    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """
        Run FreqAI, add strategy indicators and refresh the dynamic risk settings.

        Besides adding columns, this mutates strategy-wide attributes from the LAST
        row: ``buy_rsi``/``sell_rsi`` values, ``stoploss``, ``minimal_roi`` and the
        trailing-stop offsets. Non-finite last-row values leave the previous
        settings in place.
        """
        logger.info(f"Processing pair: {metadata['pair']}")
        dataframe = self.freqai.start(dataframe, metadata, self)

        # Calculate technical indicators
        dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
        dataframe["atr"] = ta.ATR(dataframe["high"], dataframe["low"], dataframe["close"], timeperiod=14)
        bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
        dataframe["bb_lowerband"] = bollinger["lower"]
        dataframe["bb_middleband"] = bollinger["mid"]
        dataframe["bb_upperband"] = bollinger["upper"]
        dataframe = self.calculate_macd(dataframe)

        # Dynamic parameter columns. These are created after freqai.start(), so they
        # are not FreqAI targets despite the `&-` prefix.
        if "&-buy_rsi" in dataframe.columns:
            dataframe["&-sell_rsi"] = dataframe["&-buy_rsi"] + 20
        dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
        dataframe["&-stoploss"] = -2 * dataframe["atr"] / dataframe["close"]
        dataframe["&-roi_0"] = (dataframe["close"] / dataframe["close"].shift(10) - 1).clip(0, 0.1).fillna(0)

        # Heuristic (non-model) parameter estimates
        dataframe["buy_rsi_pred"] = dataframe["rsi"].rolling(window=10).mean().clip(20, 50).fillna(dataframe["rsi"].median())
        dataframe["sell_rsi_pred"] = dataframe["buy_rsi_pred"] + 20
        # Clip the stoploss ratio itself to [-20%, -5%]. (Clipping `close`, a positive
        # price, to a negative range turned this into a large positive number.)
        dataframe["stoploss_pred"] = (-2 * dataframe["atr"] / dataframe["close"]).clip(-0.2, -0.05)
        # NOTE(review): roi_0_pred is computed but minimal_roi below uses roi_0.value.
        dataframe["roi_0_pred"] = dataframe["&-roi_0"].clip(0.01, 0.1).fillna(dataframe["&-roi_0"].mean())
        for col in ["&-stoploss", "&-roi_0", "buy_rsi_pred", "sell_rsi_pred", "stoploss_pred", "roi_0_pred"]:
            dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan).ffill().fillna(dataframe[col].mean())

        # Set strategy parameters from the last row.
        # NOTE(review): these attributes are shared by all pairs and are taken from
        # the last row of whichever pair ran most recently; if populate_indicators
        # sees the full history in backtesting, every trade would use end-of-history
        # values (lookahead). Consider custom_stoploss/custom_exit instead.
        self.buy_rsi.value = float(dataframe["buy_rsi_pred"].iloc[-1])
        self.sell_rsi.value = float(dataframe["sell_rsi_pred"].iloc[-1])
        last_stoploss = float(dataframe["stoploss_pred"].iloc[-1])
        if np.isfinite(last_stoploss):
            self.stoploss = last_stoploss
        atr_ratio = float(dataframe["atr"].iloc[-1] / dataframe["close"].iloc[-1])
        if np.isfinite(atr_ratio):
            self.minimal_roi = {
                0: float(self.roi_0.value) + atr_ratio,
                15: float(self.roi_15.value) * 0.8,
                30: float(self.roi_30.value) * 0.5,
                60: 0,
            }
            # Dynamic trailing stop
            self.trailing_stop_positive = 1.5 * atr_ratio
            self.trailing_stop_positive_offset = 2.5 * atr_ratio

        # Data cleaning
        dataframe.ffill(inplace=True)
        dataframe.fillna(dataframe.mean(numeric_only=True), inplace=True)
        if "&-up_or_down" in dataframe.columns:
            logger.info(f"&-up_or_down stats:\n{dataframe['&-up_or_down'].describe().to_string()}")
        return dataframe

    def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """
        Set ``enter_long``/``enter_tag`` where at least 3 of 5 entry conditions hold.

        Returns the frame unchanged (and logs an error) if a required column is missing.
        """
        # Validate required columns (all produced by populate_indicators/FreqAI)
        required_cols = ["rsi", "buy_rsi_pred", "volume", "bb_middleband", "macd", "macdsignal", "&-up_or_down"]
        if not all(col in dataframe.columns for col in required_cols):
            logger.error(f"Missing required columns: {set(required_cols) - set(dataframe.columns)}")
            return dataframe
        enter_long_conditions = [
            dataframe["rsi"] < dataframe["buy_rsi_pred"],
            dataframe["volume"] > dataframe["volume"].rolling(window=10).mean() * 1.0,
            dataframe["close"] > dataframe["bb_middleband"],
            dataframe["macd"] > dataframe["macdsignal"],
            dataframe["&-up_or_down"] > (0.001 if metadata["pair"] == "BTC/USDT" else 0.0015),  # Lower threshold for BTC
        ]
        # Count satisfied conditions per row
        conditions_met = pd.concat(enter_long_conditions, axis=1).sum(axis=1)
        dataframe.loc[conditions_met >= 3, ["enter_long", "enter_tag"]] = (1, "long")
        return dataframe

    def populate_exit_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """
        Set ``exit_long`` where at least 3 of 4 exit conditions hold on this candle
        and on the previous one.

        Returns the frame unchanged (and logs an error) if a required column is missing.
        """
        # Validate required columns
        required_cols = ["rsi", "sell_rsi_pred", "bb_middleband", "macd", "macdsignal", "&-up_or_down"]
        if not all(col in dataframe.columns for col in required_cols):
            logger.error(f"Missing required columns: {set(required_cols) - set(dataframe.columns)}")
            return dataframe
        exit_long_conditions = [
            dataframe["rsi"] > dataframe["sell_rsi_pred"],
            dataframe["close"] < dataframe["bb_middleband"],
            dataframe["macd"] < dataframe["macdsignal"],
            dataframe["&-up_or_down"] < -0.005,
        ]
        met = pd.concat(exit_long_conditions, axis=1).sum(axis=1) >= 3
        # Confirmation: shift the boolean result (not the condition frame) so the
        # first row compares against False instead of an object-dtype NaN row.
        exit_signal = met & met.shift(1, fill_value=False)
        dataframe.loc[exit_signal, "exit_long"] = 1
        return dataframe

    def confirm_trade_entry(
        self, pair: str, order_type: str, amount: float, rate: float,
        time_in_force: str, current_time, entry_tag, side: str, **kwargs
    ) -> bool:
        """
        Reject long entries priced more than 0.25% above the last analyzed close.

        Trades are accepted when there is no analyzed candle to compare against.
        """
        df, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        if df is None or df.empty:
            return True
        last_candle = df.iloc[-1].squeeze()
        if side == "long" and rate > (last_candle["close"] * (1 + 0.0025)):
            return False
        return True