import logging
from functools import reduce

import numpy as np
import pandas as pd
import talib.abstract as ta
from pandas import DataFrame
from technical import qtpylib

from freqtrade.strategy import IStrategy, IntParameter, DecimalParameter

logger = logging.getLogger(__name__)


class FreqaiExampleStrategy(IStrategy):
    """RSI/TEMA/Bollinger long-only strategy wired to a FreqAI LightGBM regressor.

    The class defines the FreqAI feature-engineering hooks, one regression
    target (``&-buy_rsi``), the indicator pipeline and the entry/exit rules.
    """

    minimal_roi = {
        "0": 0.02,
        "7": 0.01,
        "13": 0.005,
        "60": 0
    }

    # Was 0.0, which is not a usable stop-loss (freqtrade's configuration
    # validation refuses a stoploss of 0). Use the default of
    # ``stoploss_param`` declared below so the two values agree.
    stoploss = -0.1
    trailing_stop = True
    process_only_new_candles = True
    use_exit_signal = True
    startup_candle_count: int = 40
    can_short = False

    buy_rsi = IntParameter(low=10, high=50, default=27, space="buy", optimize=True, load=True)
    sell_rsi = IntParameter(low=50, high=90, default=59, space="sell", optimize=True, load=True)
    roi_0 = DecimalParameter(low=0.01, high=0.2, default=0.038, space="roi", optimize=True, load=True)
    roi_15 = DecimalParameter(low=0.005, high=0.1, default=0.027, space="roi", optimize=True, load=True)
    roi_30 = DecimalParameter(low=0.001, high=0.05, default=0.009, space="roi", optimize=True, load=True)
    stoploss_param = DecimalParameter(low=-0.25, high=-0.05, default=-0.1, space="stoploss", optimize=True, load=True)
    # NOTE(review): this name looks like it shadows the strategy-level
    # ``trailing_stop_positive_offset`` setting, which freqtrade presumably
    # expects to be a plain float - confirm and consider renaming.
    trailing_stop_positive_offset = DecimalParameter(low=0.005, high=0.5, default=0.01, space="trailing", optimize=True, load=True)

    protections = []

    freqai_info = {
        "model": "LightGBMRegressor",
        "feature_parameters": {
            "include_timeframes": ["5m"],
            "include_corr_pairlist": ["SOL/USDT"],
            "label_period_candles": 12,
            "include_shifted_candles": 0,
            "include_periods": [20],
            "DI_threshold": 5.0
        },
        "data_split_parameters": {
            "test_size": 0.2,
            "shuffle": False,
        },
        "model_training_parameters": {
            "n_estimators": 100,
            "learning_rate": 0.1,
            "num_leaves": 15,
            "n_jobs": 4,
            "verbosity": -1
        },
    }

    plot_config = {
        "main_plot": {
            "close": {"color": "blue"},
            "bb_lowerband": {"color": "purple"}
        },
        "subplots": {
            "&-buy_rsi": {"&-buy_rsi": {"color": "green"}},
            "&-sell_rsi": {"&-sell_rsi": {"color": "red"}},
            "rsi": {"rsi": {"color": "black"}},
            "do_predict": {"do_predict": {"color": "brown"}},
            "trade_signals": {
                "enter_long": {"color": "green", "type": "scatter"},
                "exit_long": {"color": "red", "type": "scatter"}
            }
        }
    }

    @staticmethod
    def _add_zscored_volatility(dataframe: DataFrame) -> None:
        """Write the ``%-volatility`` column into ``dataframe`` in place.

        The value is the 20-row rolling standard deviation of the close
        percentage change, NaN-filled with 0 and then z-scored when its own
        standard deviation is positive. Frames shorter than 20 rows, or with
        any NaN close, get a constant 0 column and a warning instead.
        """
        if len(dataframe) < 20 or dataframe["close"].isna().any():
            logger.warning(f"DataFrame too short ({len(dataframe)} rows) or contains NaN in close, cannot compute %-volatility")
            dataframe["%-volatility"] = 0
            return

        volatility = dataframe["close"].pct_change().rolling(20).std().fillna(0)
        spread = volatility.std()
        # Skip the normalisation for a constant series to avoid dividing by 0.
        if spread > 0:
            volatility = (volatility - volatility.mean()) / spread
        dataframe["%-volatility"] = volatility

    @staticmethod
    def _expected_feature_columns(freqai_config: dict) -> list:
        """Build the feature column names this strategy expects to see.

        Periods, correlated pairs and timeframes are all read from the
        ``feature_parameters`` section of ``freqai_config``; the literal
        defaults are only used when a key is absent. Period-independent
        features (``%-pct-change``) are listed once per pair/timeframe.
        """
        feature_params = freqai_config.get("feature_parameters", {})
        periods = feature_params.get("include_periods", [10, 20])
        pairs = feature_params.get("include_corr_pairlist", ["SOL/USDT", "BTC/USDT"])
        timeframes = feature_params.get("include_timeframes", ["5m"])

        base_names = [
            f"%-{indicator}-period_{period}"
            for indicator in ("rsi", "bb_width")
            for period in periods
        ]
        base_names.append("%-pct-change")

        expected_columns = ["%-volatility", "%-day_of_week", "%-hour_of_day"]
        expected_columns.extend(
            f"{base}_{pair}_{timeframe}"
            for base in base_names
            for pair in pairs
            for timeframe in timeframes
        )
        return expected_columns

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
        """Add the period-dependent features (RSI and Bollinger band width).

        Infinite values are replaced by 0 and remaining NaNs are forward
        filled, then filled with 0, across the whole frame.
        """
        dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period)
        bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=period, stds=2.2)
        dataframe["%-bb_width-period"] = (bollinger["upper"] - bollinger["lower"]) / bollinger["mid"]
        dataframe = dataframe.replace([np.inf, -np.inf], 0).ffill().fillna(0)
        return dataframe

    def feature_engineering_expand_basic(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add the period-independent close percentage-change feature."""
        dataframe["%-pct-change"] = dataframe["close"].pct_change()
        dataframe = dataframe.replace([np.inf, -np.inf], 0).ffill().fillna(0)
        return dataframe

    def feature_engineering_standard(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add the z-scored volatility and the day-of-week/hour-of-day features."""
        self._add_zscored_volatility(dataframe)
        dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
        dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
        dataframe = dataframe.replace([np.inf, -np.inf], 0).ffill().fillna(0)
        return dataframe

    def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Create the regression target ``&-buy_rsi``.

        The target is the 14-period RSI ``label_period_candles`` rows ahead.
        Rows where it is undefined (RSI warm-up at the head, the shifted tail)
        are left as NaN rather than filled, so no fabricated label values are
        produced for them.

        Raises:
            ValueError: if ``dataframe`` has no ``close`` column.
        """
        logger.info(f"Setting FreqAI targets for pair: {metadata['pair']}")

        if "close" not in dataframe.columns:
            logger.error("DataFrame missing required 'close' column")
            raise ValueError("DataFrame missing required 'close' column")

        try:
            label_period = self.freqai_info["feature_parameters"]["label_period_candles"]

            self._add_zscored_volatility(dataframe)
            dataframe["%-volatility"] = (
                dataframe["%-volatility"].replace([np.inf, -np.inf], 0).ffill().fillna(0)
            )

            future_rsi = ta.RSI(dataframe, timeperiod=14).shift(-label_period)
            # Non-finite labels become NaN (not 0) so they are treated as
            # missing instead of as a real RSI reading.
            dataframe["&-buy_rsi"] = future_rsi.replace([np.inf, -np.inf], np.nan)
        except Exception as e:
            logger.exception(f"Failed to create FreqAI targets: {str(e)}")
            raise

        logger.info(f"Target columns preview: {dataframe[['&-buy_rsi']].head().to_string()}")
        return dataframe

    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Compute indicators, run FreqAI and log diagnostics.

        Adds ``rsi``, ``%-volatility``, ``tema``, ``bb_lowerband``,
        ``up_or_down`` and the date features, calls ``self.freqai.start`` and
        makes sure ``buy_rsi_pred``/``sell_rsi_pred`` exist afterwards. The
        returned frame has infinities replaced by 0 and NaNs filled.

        Raises:
            ValueError: if ``close`` is missing or entirely NaN.
        """
        pair = metadata["pair"]
        row_count = len(dataframe)
        logger.info(f"Processing pair: {pair}")
        logger.info(f"DataFrame rows: {row_count}")
        logger.info(f"Columns before freqai.start: {list(dataframe.columns)}")

        if "close" not in dataframe.columns or dataframe["close"].isna().all():
            logger.error(f"DataFrame missing 'close' column or all NaN for pair: {pair}")
            raise ValueError("DataFrame missing valid 'close' column")

        if row_count < 14:
            logger.warning(f"DataFrame too short ({row_count} rows), cannot compute rsi")
            dataframe["rsi"] = 50
        else:
            dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
        logger.info(f"rsi stats: {dataframe['rsi'].describe().to_string()}")

        self._add_zscored_volatility(dataframe)
        logger.info(f"%-volatility stats: {dataframe['%-volatility'].describe().to_string()}")

        if row_count < 9:
            logger.warning(f"DataFrame too short ({row_count} rows), cannot compute tema")
            dataframe["tema"] = dataframe["close"]
        else:
            dataframe["tema"] = ta.TEMA(dataframe, timeperiod=9)
            if dataframe["tema"].isna().any():
                logger.warning("tema contains NaN, filling with close")
                dataframe["tema"] = dataframe["tema"].fillna(dataframe["close"])
        logger.info(f"tema stats: {dataframe['tema'].describe().to_string()}")

        if row_count < 20:
            logger.warning(f"DataFrame too short ({row_count} rows), cannot compute bb_lowerband")
            dataframe["bb_lowerband"] = dataframe["close"]
        else:
            bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2.2)
            dataframe["bb_lowerband"] = bollinger["lower"]
            if dataframe["bb_lowerband"].isna().any():
                logger.warning("bb_lowerband contains NaN, filling with close")
                dataframe["bb_lowerband"] = dataframe["bb_lowerband"].fillna(dataframe["close"])
        logger.info(f"bb_lowerband stats: {dataframe['bb_lowerband'].describe().to_string()}")

        # Generate up_or_down.
        # NOTE(review): this compares each close with the close
        # ``label_period`` rows ahead, i.e. it reads future rows. The last
        # ``label_period`` rows compare against NaN and are therefore always
        # 0, while populate_entry_trend requires up_or_down == 1 - confirm
        # this is intended.
        label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
        if row_count < label_period + 1:
            logger.warning(f"DataFrame too short ({row_count} rows), cannot compute up_or_down")
            dataframe["up_or_down"] = 0
        else:
            # np.where yields integer 0/1 only, so no NaN handling is needed.
            dataframe["up_or_down"] = np.where(
                dataframe["close"].shift(-label_period) > dataframe["close"], 1, 0
            )
        logger.info(f"up_or_down stats: {dataframe['up_or_down'].describe().to_string()}")

        if "date" in dataframe.columns:
            dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
            dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
        else:
            logger.warning("Missing 'date' column, skipping %-day_of_week and %-hour_of_day")
            dataframe["%-day_of_week"] = 0
            dataframe["%-hour_of_day"] = 0

        try:
            dataframe = self.freqai.start(dataframe, metadata, self)
            logger.info(f"Columns after freqai.start: {list(dataframe.columns)}")
        except Exception as e:
            # Best-effort fallback: keep the strategy running on constant
            # thresholds, but record the traceback for diagnosis.
            logger.exception(f"freqai.start failed: {str(e)}")
            dataframe["buy_rsi_pred"] = 50
            dataframe["sell_rsi_pred"] = 80
            dataframe["do_predict"] = 1

        # NOTE(review): nothing in this file creates buy_rsi_pred or
        # sell_rsi_pred; the only target defined here is "&-buy_rsi". Unless
        # freqai.start adds these columns the constants below are always
        # used - confirm the intended column names.
        fallback_thresholds = {"buy_rsi_pred": 50, "sell_rsi_pred": 80}
        for col, fallback in fallback_thresholds.items():
            if col not in dataframe.columns:
                logger.error(f"Error: {col} column not generated for pair: {pair}")
                dataframe[col] = fallback
            logger.info(f"{col} stats: {dataframe[col].describe().to_string()}")

        # Debug feature distribution.
        # NOTE(review): period 10 is hard-coded while include_periods above is
        # [20], so this branch looks unreachable with the current config.
        debug_col = "%-bb_width-period_10_SOL/USDT_5m"
        if debug_col in dataframe.columns:
            debug_std = dataframe[debug_col].std()
            if debug_std > 0:
                dataframe[debug_col] = (dataframe[debug_col] - dataframe[debug_col].mean()) / debug_std
            logger.info(f"%-bb_width-period_10 stats: {dataframe[debug_col].describe().to_string()}")

        expected_columns = self._expected_feature_columns(self.freqai_info)
        logger.info(f"Expected feature columns ({len(expected_columns)}): {expected_columns[:10]}...")

        actual_columns = list(dataframe.columns)
        expected_lookup = set(expected_columns)
        actual_lookup = set(actual_columns)
        missing_columns = [col for col in expected_columns if col not in actual_lookup]
        extra_columns = [
            col for col in actual_columns
            if col not in expected_lookup and col.startswith("%-")
        ]
        logger.info(f"Missing columns ({len(missing_columns)}): {missing_columns}")
        logger.info(f"Extra columns ({len(extra_columns)}): {extra_columns}")

        if "DI_values" in dataframe.columns:
            logger.info(f"DI_values stats: {dataframe['DI_values'].describe().to_string()}")
            logger.info(f"DI discarded predictions: {len(dataframe[dataframe['do_predict'] == 0])}")

        dataframe = dataframe.replace([np.inf, -np.inf], 0).ffill().fillna(0)

        logger.info(f"Final columns in populate_indicators: {list(dataframe.columns)}")

        return dataframe

    def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
        """Flag long entries.

        A row qualifies when RSI crosses above ``buy_rsi_pred`` (plus 5 for
        BTC/USDT), TEMA is rising, volume is positive, ``do_predict`` is 1 and
        ``up_or_down`` is 1. The signal is then held for one extra row via a
        2-row rolling maximum, and every held row is tagged ``"long"``.
        """
        rsi_offset = 5 if metadata["pair"] == "BTC/USDT" else 0
        enter_long_conditions = [
            qtpylib.crossed_above(df["rsi"], df["buy_rsi_pred"] + rsi_offset),
            df["tema"] > df["tema"].shift(1),
            df["volume"] > 0,
            df["do_predict"] == 1,
            df["up_or_down"] == 1
        ]

        df["entry_signal"] = reduce(lambda x, y: x & y, enter_long_conditions)
        # The rolling max is a superset of the raw signal, so one assignment
        # of enter_long/enter_tag covers both the raw and the held rows.
        df["entry_signal"] = df["entry_signal"].rolling(window=2, min_periods=1).max().astype(bool)
        df.loc[df["entry_signal"], ["enter_long", "enter_tag"]] = (1, "long")

        if not df.empty and df["entry_signal"].iloc[-1]:
            logger.info(f"Entry signal triggered for {metadata['pair']}: rsi={df['rsi'].iloc[-1]}, buy_rsi_pred={df['buy_rsi_pred'].iloc[-1]}, do_predict={df['do_predict'].iloc[-1]}")

        return df

    def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
        """Flag long exits.

        A row exits when any price trigger fires (RSI crossing above
        ``sell_rsi_pred``, a close more than 2% below the previous close, or a
        close under the lower Bollinger band) together with positive volume,
        ``do_predict == 1`` and ``up_or_down == 0`` - or when the row's date
        is at least one day after the previous row's date.
        """
        price_trigger = (
            qtpylib.crossed_above(df["rsi"], df["sell_rsi_pred"])
            | (df["close"] < df["close"].shift(1) * 0.98)
            | (df["close"] < df["bb_lowerband"])
        )
        exit_long_conditions = [
            price_trigger,
            df["volume"] > 0,
            df["do_predict"] == 1,
            df["up_or_down"] == 0
        ]
        time_exit = (df["date"] >= df["date"].shift(1) + pd.Timedelta(days=1))
        df.loc[
            (reduce(lambda x, y: x & y, exit_long_conditions)) | time_exit,
            "exit_long"
        ] = 1
        return df

    def confirm_trade_entry(
        self,
        pair: str,
        order_type: str,
        amount: float,
        rate: float,
        time_in_force: str,
        current_time,
        entry_tag,
        side: str,
        **kwargs
    ) -> bool:
        """Accept or reject an entry order just before it is placed.

        Long market orders are always accepted. Other long orders are
        accepted only when ``rate`` is at most 1% above the close of the last
        analyzed candle. Every non-long side is accepted.
        """
        logger.info(f"Confirming trade entry for {pair}, order_type: {order_type}, rate: {rate}, current_time: {current_time}")
        df, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        last_candle = df.iloc[-1].squeeze()

        if side != "long":
            return True

        if order_type == "market":
            logger.info(f"Order confirmed for {pair}, rate: {rate} (market order)")
            return True

        if rate <= (last_candle["close"] * (1 + 0.01)):
            logger.info(f"Order confirmed for {pair}, rate: {rate}")
            return True

        logger.info(f"Order rejected: rate {rate} exceeds threshold {last_candle['close'] * 1.01}")
        return False