# File: myTestFreqAI/freqtrade/templates/FreqaiExampleStrategy.py
#
# (Repository viewer metadata at export time: 313 lines, 16 KiB, Python.)

import logging
import numpy as np
import pandas as pd
from functools import reduce
import talib.abstract as ta
from pandas import DataFrame
from technical import qtpylib
from freqtrade.strategy import IStrategy, IntParameter, DecimalParameter
# Module-level logger, named after this module so its output can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class FreqaiExampleStrategy(IStrategy):
    """Long-only strategy combining RSI/TEMA/Bollinger rules with FreqAI predictions.

    FreqAI is trained to predict the 14-period RSI ``label_period_candles`` candles
    ahead (label ``&-buy_rsi``). That prediction becomes ``buy_rsi_pred``, the RSI
    level the current RSI must cross above to enter. Entries also require a rising
    TEMA, non-zero volume, ``do_predict == 1`` and positive trailing momentum
    (``up_or_down == 1``). Exit signals fire on an RSI cross above
    ``sell_rsi_pred``, a >2% candle-to-candle drop, or a close below the lower
    Bollinger band. ``custom_exit`` closes trades held for ``time_exit_after``.
    """

    minimal_roi = {
        "0": 0.02,
        "7": 0.01,
        "13": 0.005,
        "60": 0,
    }
    # NOTE(review): with stoploss 0.0 and trailing_stop=True the stop appears to sit
    # at the entry/peak price itself, and ``stoploss_param`` (default -0.1) below is
    # never read. Confirm this is intended before trading live.
    stoploss = 0.0
    trailing_stop = True
    process_only_new_candles = True
    use_exit_signal = True
    startup_candle_count: int = 40
    can_short = False

    # Hyperopt parameters.
    # NOTE(review): none of these is read anywhere in this class. The name
    # ``trailing_stop_positive_offset`` also matches freqtrade's built-in
    # trailing-stop setting, so a Parameter object sits where a float may be
    # expected -- confirm, or rename it.
    buy_rsi = IntParameter(low=10, high=50, default=27, space="buy", optimize=True, load=True)
    sell_rsi = IntParameter(low=50, high=90, default=59, space="sell", optimize=True, load=True)
    roi_0 = DecimalParameter(low=0.01, high=0.2, default=0.038, space="roi", optimize=True, load=True)
    roi_15 = DecimalParameter(low=0.005, high=0.1, default=0.027, space="roi", optimize=True, load=True)
    roi_30 = DecimalParameter(low=0.001, high=0.05, default=0.009, space="roi", optimize=True, load=True)
    stoploss_param = DecimalParameter(low=-0.25, high=-0.05, default=-0.1, space="stoploss", optimize=True, load=True)
    trailing_stop_positive_offset = DecimalParameter(
        low=0.005, high=0.5, default=0.01, space="trailing", optimize=True, load=True
    )
    protections = []

    # Maximum trade age; ``custom_exit`` closes any trade held at least this long.
    time_exit_after = pd.Timedelta(days=1)
    # ``confirm_trade_entry`` rejects non-market long entries priced more than this
    # fraction above the last analyzed close.
    entry_rate_tolerance = 0.01
    # (rule column, FreqAI prediction column it is copied from, fallback constant).
    _prediction_columns = (
        ("buy_rsi_pred", "&-buy_rsi", 50),
        ("sell_rsi_pred", "&-sell_rsi", 80),
    )

    # FreqAI settings.
    # NOTE(review): FreqAI normally reads its settings from the bot config's
    # ``freqai`` section; confirm whether this dict is used or replaced at runtime.
    # ``include_periods`` is only read by this class's diagnostics (FreqAI's own key
    # for indicator periods is ``indicator_periods_candles``).
    freqai_info = {
        "model": "LightGBMRegressor",
        "feature_parameters": {
            "include_timeframes": ["5m"],
            "include_corr_pairlist": ["SOL/USDT"],
            "label_period_candles": 12,
            "include_shifted_candles": 0,
            "include_periods": [20],
            "DI_threshold": 5.0,
        },
        "data_split_parameters": {
            "test_size": 0.2,
            "shuffle": False,
        },
        "model_training_parameters": {
            "n_estimators": 100,
            "learning_rate": 0.1,
            "num_leaves": 15,
            "n_jobs": 4,
            "verbosity": -1,
        },
    }

    # NOTE(review): ``&-sell_rsi`` is not created by set_freqai_targets, so that
    # subplot appears to stay empty.
    plot_config = {
        "main_plot": {
            "close": {"color": "blue"},
            "bb_lowerband": {"color": "purple"},
        },
        "subplots": {
            "&-buy_rsi": {"&-buy_rsi": {"color": "green"}},
            "&-sell_rsi": {"&-sell_rsi": {"color": "red"}},
            "rsi": {"rsi": {"color": "black"}},
            "do_predict": {"do_predict": {"color": "brown"}},
            "trade_signals": {
                "enter_long": {"color": "green", "type": "scatter"},
                "exit_long": {"color": "red", "type": "scatter"},
            },
        },
    }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _clean(dataframe: DataFrame) -> DataFrame:
        """Replace +/-inf with 0, forward-fill, then fill any remaining NaN with 0."""
        return dataframe.replace([np.inf, -np.inf], 0).ffill().fillna(0)

    @staticmethod
    def _add_volatility(dataframe: DataFrame, window: int = 20) -> DataFrame:
        """Add ``%-volatility``: rolling std of close-to-close returns over ``window``.

        The value is deliberately NOT z-scored with the mean/std of the whole frame.
        Those statistics differ between the training window and the shorter live
        window, so the same market state would map to different feature values at
        train and predict time. It would also mix future rows into past ones.
        """
        if len(dataframe) < window or dataframe["close"].isna().any():
            logger.warning(
                f"DataFrame too short ({len(dataframe)} rows) or contains NaN in close, cannot compute %-volatility"
            )
            dataframe["%-volatility"] = 0
        else:
            dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(window).std().fillna(0)
        return dataframe

    @staticmethod
    def _get_expected_columns(freqai_config: dict, pair: "str | None" = None) -> list:
        """Return the feature column names this strategy expects FreqAI to build (diagnostics only).

        Pairs, timeframes and periods are read from ``feature_parameters``, where
        they are configured. ``%-pct-change`` has no period, so it is listed once
        per pair/timeframe rather than once per period.
        """
        feature_params = freqai_config.get("feature_parameters", {})
        periods = feature_params.get(
            "indicator_periods_candles", feature_params.get("include_periods", [10, 20])
        )
        corr_pairs = feature_params.get("include_corr_pairlist", ["SOL/USDT", "BTC/USDT"])
        timeframes = feature_params.get("include_timeframes", ["5m"])
        # The traded pair gets features too; dict.fromkeys de-duplicates while keeping order.
        pairs = list(dict.fromkeys(([pair] if pair else []) + list(corr_pairs)))

        expected_columns = ["%-volatility", "%-day_of_week", "%-hour_of_day"]
        for feature_pair in pairs:
            for timeframe in timeframes:
                suffix = f"_{feature_pair}_{timeframe}"
                for indicator in ("rsi", "bb_width"):
                    expected_columns.extend(f"%-{indicator}-period_{period}{suffix}" for period in periods)
                expected_columns.append(f"%-pct-change{suffix}")
        return expected_columns

    def _add_rule_indicators(self, dataframe: DataFrame) -> DataFrame:
        """Add the non-FreqAI columns used by the entry/exit rules and the plots."""
        if len(dataframe) < 14:
            logger.warning(f"DataFrame too short ({len(dataframe)} rows), cannot compute rsi")
            dataframe["rsi"] = 50
        else:
            dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
        logger.info(f"rsi stats: {dataframe['rsi'].describe().to_string()}")

        dataframe = self._add_volatility(dataframe)
        logger.info(f"%-volatility stats: {dataframe['%-volatility'].describe().to_string()}")

        if len(dataframe) < 9:
            logger.warning(f"DataFrame too short ({len(dataframe)} rows), cannot compute tema")
            dataframe["tema"] = dataframe["close"]
        else:
            dataframe["tema"] = ta.TEMA(dataframe, timeperiod=9)
        if dataframe["tema"].isna().any():
            logger.warning("tema contains NaN, filling with close")
            dataframe["tema"] = dataframe["tema"].fillna(dataframe["close"])
        logger.info(f"tema stats: {dataframe['tema'].describe().to_string()}")

        if len(dataframe) < 20:
            logger.warning(f"DataFrame too short ({len(dataframe)} rows), cannot compute bb_lowerband")
            dataframe["bb_lowerband"] = dataframe["close"]
        else:
            bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2.2)
            dataframe["bb_lowerband"] = bollinger["lower"]
        if dataframe["bb_lowerband"].isna().any():
            logger.warning("bb_lowerband contains NaN, filling with close")
            dataframe["bb_lowerband"] = dataframe["bb_lowerband"].fillna(dataframe["close"])
        logger.info(f"bb_lowerband stats: {dataframe['bb_lowerband'].describe().to_string()}")

        # Direction flag used by the entry/exit rules. The previous version compared
        # close with close.shift(-label_period), i.e. a FUTURE candle. Backtests then
        # peeked at future prices, and on the newest label_period rows (the only
        # ones live trading acts on) the shifted value is NaN, so the flag was
        # always 0 and entries could never fire. Only data up to the current candle
        # is used now: 1 when close is above the close label_period candles ago.
        label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
        if len(dataframe) < label_period + 1:
            logger.warning(f"DataFrame too short ({len(dataframe)} rows), cannot compute up_or_down")
            dataframe["up_or_down"] = 0
        else:
            dataframe["up_or_down"] = np.where(dataframe["close"] > dataframe["close"].shift(label_period), 1, 0)
        logger.info(f"up_or_down stats: {dataframe['up_or_down'].describe().to_string()}")

        if "date" in dataframe.columns:
            dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
            dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
        else:
            logger.warning("Missing 'date' column, skipping %-day_of_week and %-hour_of_day")
            dataframe["%-day_of_week"] = 0
            dataframe["%-hour_of_day"] = 0
        return dataframe

    def _add_freqai_predictions(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Run FreqAI and expose its predictions as ``buy_rsi_pred`` / ``sell_rsi_pred``."""
        try:
            dataframe = self.freqai.start(dataframe, metadata, self)
            logger.info(f"Columns after freqai.start: {list(dataframe.columns)}")
        except Exception as e:
            # Deliberate best-effort: keep analysing with constant thresholds.
            # NOTE(review): do_predict=1 lets the rules trade without a model -- confirm.
            logger.exception(f"freqai.start failed: {str(e)}")
            dataframe["buy_rsi_pred"] = 50
            dataframe["sell_rsi_pred"] = 80
            dataframe["do_predict"] = 1

        # FreqAI returns predictions under the label's own name (e.g. "&-buy_rsi").
        # The rule columns are copied from there; the fallback constant is used only
        # when no such prediction exists.
        for col, label, fallback in self._prediction_columns:
            if col not in dataframe.columns:
                if label in dataframe.columns:
                    dataframe[col] = dataframe[label].fillna(fallback)
                else:
                    logger.warning(
                        f"{col} not generated for pair: {metadata['pair']} (no '{label}' prediction), using {fallback}"
                    )
                    dataframe[col] = fallback
            logger.info(f"{col} stats: {dataframe[col].describe().to_string()}")
        return dataframe

    def _log_feature_diagnostics(self, dataframe: DataFrame, metadata: dict) -> None:
        """Log feature statistics and expected-vs-actual feature columns; never modifies ``dataframe``."""
        for col in dataframe.columns:
            if isinstance(col, str) and col.startswith("%-bb_width-period"):
                logger.info(f"{col} stats: {dataframe[col].describe().to_string()}")

        expected_columns = self._get_expected_columns(self.freqai_info, metadata.get("pair"))
        logger.info(f"Expected feature columns ({len(expected_columns)}): {expected_columns[:10]}...")
        actual = set(dataframe.columns)
        expected = set(expected_columns)
        missing_columns = [col for col in expected_columns if col not in actual]
        extra_columns = [col for col in dataframe.columns if col not in expected and col.startswith("%-")]
        logger.info(f"Missing columns ({len(missing_columns)}): {missing_columns}")
        logger.info(f"Extra columns ({len(extra_columns)}): {extra_columns}")

        if "DI_values" in dataframe.columns:
            logger.info(f"DI_values stats: {dataframe['DI_values'].describe().to_string()}")
            logger.info(f"DI discarded predictions: {int((dataframe['do_predict'] == 0).sum())}")

    # ------------------------------------------------------- FreqAI callbacks

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
        """Add period-dependent features: RSI and relative Bollinger band width for ``period``."""
        dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period)
        bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=period, stds=2.2)
        dataframe["%-bb_width-period"] = (bollinger["upper"] - bollinger["lower"]) / bollinger["mid"]
        return self._clean(dataframe)

    def feature_engineering_expand_basic(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add the period-independent close-to-close percentage change feature."""
        dataframe["%-pct-change"] = dataframe["close"].pct_change()
        return self._clean(dataframe)

    def feature_engineering_standard(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add volatility and calendar features (day of week, hour of day)."""
        dataframe = self._add_volatility(dataframe)
        dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
        dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
        return self._clean(dataframe)

    def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Create the training label ``&-buy_rsi``: 14-period RSI ``label_period_candles`` ahead.

        Rows whose label is unknown (RSI warm-up at the head, the last
        ``label_period`` rows at the tail) are left NaN, and inf becomes NaN. The
        previous ffill/fillna(0) turned those rows into fabricated labels (RSI = 0,
        or a repeated last value) that the model then trained on.

        Raises:
            ValueError: if ``dataframe`` has no ``close`` column.
        """
        logger.info(f"Setting FreqAI targets for pair: {metadata['pair']}")
        if "close" not in dataframe.columns:
            logger.error("DataFrame missing required 'close' column")
            raise ValueError("DataFrame missing required 'close' column")
        try:
            label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
            dataframe["&-buy_rsi"] = (
                ta.RSI(dataframe, timeperiod=14)
                .shift(-label_period)
                .replace([np.inf, -np.inf], np.nan)
            )
        except Exception as e:
            logger.error(f"Failed to create FreqAI targets: {str(e)}")
            raise
        logger.info(f"Target columns preview: {dataframe[['&-buy_rsi']].head().to_string()}")
        return dataframe

    # ------------------------------------------------------ strategy callbacks

    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Add rule indicators, run FreqAI, log diagnostics and return a cleaned frame.

        Raises:
            ValueError: if ``close`` is missing or entirely NaN.
        """
        logger.info(f"Processing pair: {metadata['pair']}")
        logger.info(f"DataFrame rows: {len(dataframe)}")
        logger.info(f"Columns before freqai.start: {list(dataframe.columns)}")
        if "close" not in dataframe.columns or dataframe["close"].isna().all():
            logger.error(f"DataFrame missing 'close' column or all NaN for pair: {metadata['pair']}")
            raise ValueError("DataFrame missing valid 'close' column")

        dataframe = self._add_rule_indicators(dataframe)
        dataframe = self._add_freqai_predictions(dataframe, metadata)
        self._log_feature_diagnostics(dataframe, metadata)

        dataframe = self._clean(dataframe)
        logger.info(f"Final columns in populate_indicators: {list(dataframe.columns)}")
        return dataframe

    def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
        """Mark ``enter_long``/``enter_tag`` where every entry rule holds, kept for one extra candle."""
        # BTC/USDT must clear the predicted RSI level by an extra 5 points.
        rsi_offset = 5 if metadata["pair"] == "BTC/USDT" else 0
        enter_long_conditions = [
            qtpylib.crossed_above(df["rsi"], df["buy_rsi_pred"] + rsi_offset),
            df["tema"] > df["tema"].shift(1),
            df["volume"] > 0,
            df["do_predict"] == 1,
            df["up_or_down"] == 1,
        ]
        raw_signal = reduce(lambda x, y: x & y, enter_long_conditions)
        # A 2-row rolling max keeps each signal alive on the following candle too.
        df["entry_signal"] = raw_signal.rolling(window=2, min_periods=1).max().astype(bool)
        df.loc[df["entry_signal"], ["enter_long", "enter_tag"]] = (1, "long")

        if not df.empty and df["entry_signal"].iloc[-1]:
            logger.info(
                f"Entry signal triggered for {metadata['pair']}: rsi={df['rsi'].iloc[-1]}, "
                f"buy_rsi_pred={df['buy_rsi_pred'].iloc[-1]}, do_predict={df['do_predict'].iloc[-1]}"
            )
        return df

    def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
        """Mark ``exit_long`` on an RSI/price/Bollinger trigger while the other exit rules hold.

        The time-based exit is handled in ``custom_exit``. The previous
        ``date >= date.shift(1) + 1 day`` test only flagged gaps between consecutive
        candles and never measured how long a trade had been open.
        """
        price_trigger = (
            qtpylib.crossed_above(df["rsi"], df["sell_rsi_pred"])
            | (df["close"] < df["close"].shift(1) * 0.98)
            | (df["close"] < df["bb_lowerband"])
        )
        exit_long_conditions = [
            price_trigger,
            df["volume"] > 0,
            df["do_predict"] == 1,
            df["up_or_down"] == 0,
        ]
        df.loc[reduce(lambda x, y: x & y, exit_long_conditions), "exit_long"] = 1
        return df

    def custom_exit(self, pair: str, trade, current_time, current_rate: float, current_profit: float, **kwargs):
        """Return exit reason ``"time_exit"`` once a trade has been open for ``time_exit_after``.

        Returns ``None`` (no exit) otherwise.
        """
        if current_time - trade.open_date_utc >= self.time_exit_after:
            return "time_exit"
        return None

    def confirm_trade_entry(
        self, pair: str, order_type: str, amount: float, rate: float,
        time_in_force: str, current_time, entry_tag, side: str, **kwargs
    ) -> bool:
        """Confirm long entries unless a non-market order is priced too far above the last close.

        Market orders, non-long sides and pairs with no analyzed candles are always
        confirmed. Other long orders are rejected when ``rate`` exceeds the last
        analyzed close by more than ``entry_rate_tolerance``.
        """
        logger.info(f"Confirming trade entry for {pair}, order_type: {order_type}, rate: {rate}, current_time: {current_time}")
        if side != "long":
            return True
        if order_type == "market":
            logger.info(f"Order confirmed for {pair}, rate: {rate} (market order)")
            return True

        df, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        if df.empty:
            logger.warning(f"No analyzed candles for {pair}, cannot check entry rate; confirming order")
            return True

        threshold = df["close"].iloc[-1] * (1 + self.entry_rate_tolerance)
        if rate <= threshold:
            logger.info(f"Order confirmed for {pair}, rate: {rate}")
            return True
        logger.info(f"Order rejected: rate {rate} exceeds threshold {threshold}")
        return False