This commit is contained in:
zhangkun9038@dingtalk.com 2025-04-29 17:08:41 +08:00
parent 0d914c2b76
commit fdc1fa0374
4 changed files with 1 additions and 666 deletions

View File

@ -1,297 +0,0 @@
import logging
import numpy as np # noqa
import pandas as pd # noqa
import talib.abstract as ta
from pandas import DataFrame
from technical import qtpylib
from freqtrade.strategy import IntParameter, IStrategy, merge_informative_pair
logger = logging.getLogger(__name__)
class FreqaiExampleHybridStrategy(IStrategy):
"""
Example of a hybrid FreqAI strat, designed to illustrate how a user may employ
FreqAI to bolster a typical Freqtrade strategy.
Launching this strategy would be:
freqtrade trade --strategy FreqaiExampleHybridStrategy --strategy-path freqtrade/templates
--freqaimodel CatboostClassifier --config config_examples/config_freqai.example.json
or the user simply adds this to their config:
"freqai": {
"enabled": true,
"purge_old_models": 2,
"train_period_days": 15,
"identifier": "unique-id",
"feature_parameters": {
"include_timeframes": [
"3m",
"15m",
"1h"
],
"include_corr_pairlist": [
"BTC/USDT",
"ETH/USDT"
],
"label_period_candles": 20,
"include_shifted_candles": 2,
"DI_threshold": 0.9,
"weight_factor": 0.9,
"principal_component_analysis": false,
"use_SVM_to_remove_outliers": true,
"indicator_periods_candles": [10, 20]
},
"data_split_parameters": {
"test_size": 0,
"random_state": 1
},
"model_training_parameters": {
"n_estimators": 200,
"max_depth": 5,
"learning_rate": 0.05
}
},
Thanks to @smarmau and @johanvulgt for developing and sharing the strategy.
"""
minimal_roi = {
# "120": 0.0, # exit after 120 minutes at break even
"60": 0.01,
"30": 0.02,
"0": 0.04,
}
plot_config = {
"main_plot": {
"tema": {},
},
"subplots": {
"MACD": {
"macd": {"color": "blue"},
"macdsignal": {"color": "orange"},
},
"RSI": {
"rsi": {"color": "red"},
},
"Up_or_down": {
"&s-up_or_down": {"color": "green"},
},
},
}
process_only_new_candles = True
stoploss = -0.05
use_exit_signal = True
startup_candle_count: int = 30
can_short = False
# Hyperoptable parameters
buy_rsi = IntParameter(low=1, high=50, default=30, space="buy", optimize=True, load=True)
sell_rsi = IntParameter(low=50, high=100, default=70, space="sell", optimize=True, load=True)
def feature_engineering_expand_all(
self, dataframe: DataFrame, period: int, metadata: dict, **kwargs
) -> DataFrame:
"""
*Only functional with FreqAI enabled strategies*
This function will automatically expand the defined features on the config defined
`indicator_periods_candles`, `include_timeframes`, `include_shifted_candles`, and
`include_corr_pairs`. In other words, a single feature defined in this function
will automatically expand to a total of
`indicator_periods_candles` * `include_timeframes` * `include_shifted_candles` *
`include_corr_pairs` numbers of features added to the model.
All features must be prepended with `%` to be recognized by FreqAI internals.
More details on how these config defined parameters accelerate feature engineering
in the documentation at:
https://www.freqtrade.io/en/latest/freqai-parameter-table/#feature-parameters
https://www.freqtrade.io/en/latest/freqai-feature-engineering/#defining-the-features
:param dataframe: strategy dataframe which will receive the features
:param period: period of the indicator - usage example:
:param metadata: metadata of current pair
dataframe["%-ema-period"] = ta.EMA(dataframe, timeperiod=period)
"""
dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period)
dataframe["%-ema-period"] = ta.EMA(dataframe, timeperiod=period)
bollinger = qtpylib.bollinger_bands(
qtpylib.typical_price(dataframe), window=period, stds=2
)
dataframe["bb_lowerband-period"] = bollinger["lower"]
dataframe["bb_middleband-period"] = bollinger["mid"]
dataframe["bb_upperband-period"] = bollinger["upper"]
dataframe["%-bb_width-period"] = (
dataframe["bb_upperband-period"] - dataframe["bb_lowerband-period"]
) / dataframe["bb_middleband-period"]
return dataframe
def feature_engineering_expand_basic(
self, dataframe: DataFrame, metadata: dict, **kwargs
) -> DataFrame:
"""
*Only functional with FreqAI enabled strategies*
This function will automatically expand the defined features on the config defined
`include_timeframes`, `include_shifted_candles`, and `include_corr_pairs`.
In other words, a single feature defined in this function
will automatically expand to a total of
`include_timeframes` * `include_shifted_candles` * `include_corr_pairs`
numbers of features added to the model.
Features defined here will *not* be automatically duplicated on user defined
`indicator_periods_candles`
All features must be prepended with `%` to be recognized by FreqAI internals.
More details on how these config defined parameters accelerate feature engineering
in the documentation at:
https://www.freqtrade.io/en/latest/freqai-parameter-table/#feature-parameters
https://www.freqtrade.io/en/latest/freqai-feature-engineering/#defining-the-features
:param dataframe: strategy dataframe which will receive the features
:param metadata: metadata of current pair
dataframe["%-pct-change"] = dataframe["close"].pct_change()
dataframe["%-ema-200"] = ta.EMA(dataframe, timeperiod=200)
"""
dataframe["%-pct-change"] = dataframe["close"].pct_change()
return dataframe
def feature_engineering_standard(
self, dataframe: DataFrame, metadata: dict, **kwargs
) -> DataFrame:
"""
*Only functional with FreqAI enabled strategies*
This optional function will be called once with the dataframe of the base timeframe.
This is the final function to be called, which means that the dataframe entering this
function will contain all the features and columns created by all other
freqai_feature_engineering_* functions.
This function is a good place to do custom exotic feature extractions (e.g. tsfresh).
This function is a good place for any feature that should not be auto-expanded upon
(e.g. day of the week).
All features must be prepended with `%` to be recognized by FreqAI internals.
More details about feature engineering available:
https://www.freqtrade.io/en/latest/freqai-feature-engineering
:param dataframe: strategy dataframe which will receive the features
:param metadata: metadata of current pair
usage example: dataframe["%-day_of_week"] = (dataframe["date"].dt.dayofweek + 1) / 7
"""
dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
return dataframe
def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
"""
Redefined target variable to predict whether the price will increase or decrease in the future.
"""
logger.info(f"Setting FreqAI targets for pair: {metadata['pair']}")
if "close" not in dataframe.columns:
logger.error("Required 'close' column missing in dataframe")
raise ValueError("Required 'close' column missing in dataframe")
if len(dataframe) < 50:
logger.error(f"Insufficient data: {len(dataframe)} rows, need at least 50 for shift(-50)")
raise ValueError("Insufficient data for target calculation")
try:
# Define target variable: 1 for price increase, 0 for price decrease
dataframe["&-up_or_down"] = np.where(
dataframe["close"].shift(-50) > dataframe["close"], 1, 0
)
# Ensure target variable is a 2D array
dataframe["&-up_or_down"] = dataframe["&-up_or_down"].values.reshape(-1, 1)
except Exception as e:
logger.error(f"Failed to create &-up_or_down column: {str(e)}")
raise
logger.info("FreqAI targets set successfully")
return dataframe
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
logger.info(f"Processing pair: {metadata['pair']}")
logger.info(f"Input DataFrame shape: {dataframe.shape}")
logger.info(f"Input DataFrame columns: {list(dataframe.columns)}")
logger.info(f"Input DataFrame head:\n{dataframe[['date', 'close', 'volume']].head().to_string()}")
# Ensure FreqAI processing
logger.info("Calling self.freqai.start")
try:
dataframe = self.freqai.start(dataframe, metadata, self)
except Exception as e:
logger.error(f"self.freqai.start failed: {str(e)}")
raise
logger.info("self.freqai.start completed")
logger.info(f"Output DataFrame shape: {dataframe.shape}")
logger.info(f"Output DataFrame columns: {list(dataframe.columns)}")
# Safely log columns that exist
available_columns = [col for col in ['date', 'close', '&-up_or_down'] if col in dataframe.columns]
logger.info(f"Output DataFrame head:\n{dataframe[available_columns].head().to_string()}")
if "&-up_or_down" not in dataframe.columns:
logger.error("FreqAI did not generate the required &-up_or_down column")
raise KeyError("FreqAI did not generate the required &-up_or_down column")
# RSI
dataframe["rsi"] = ta.RSI(dataframe)
# Bollinger Bands
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
dataframe["bb_lowerband"] = bollinger["lower"]
dataframe["bb_middleband"] = bollinger["mid"]
dataframe["bb_upperband"] = bollinger["upper"]
dataframe["bb_percent"] = (dataframe["close"] - dataframe["bb_lowerband"]) / (
dataframe["bb_upperband"] - dataframe["bb_lowerband"]
)
dataframe["bb_width"] = (dataframe["bb_upperband"] - dataframe["bb_lowerband"]) / dataframe[
"bb_middleband"
]
# TEMA
dataframe["tema"] = ta.TEMA(dataframe, timeperiod=9)
return dataframe
def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
df.loc[
(
(qtpylib.crossed_above(df["rsi"], self.buy_rsi.value))
& (df["tema"] <= df["bb_middleband"])
& (df["tema"] > df["tema"].shift(1))
& (df["volume"] > 0)
),
"enter_long",
] = 1
return df
def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
df.loc[
(
(qtpylib.crossed_above(df["rsi"], self.sell_rsi.value))
& (df["tema"] > df["bb_middleband"])
& (df["tema"] < df["tema"].shift(1))
& (df["volume"] > 0)
),
"exit_long",
] = 1
return df

View File

@ -1,32 +0,0 @@
{
"strategy_name": "FreqaiExampleStrategy",
"params": {
"trailing": {
"trailing_stop": true,
"trailing_stop_positive": 0.01,
"trailing_stop_positive_offset": 0.02,
"trailing_only_offset_is_reached": false
},
"max_open_trades": {
"max_open_trades": 4
},
"buy": {
"buy_rsi": 39.92672300850069
},
"sell": {
"sell_rsi": 69.92672300850067
},
"protection": {},
"roi": {
"0": 0.132,
"8": 0.047,
"14": 0.007,
"60": 0
},
"stoploss": {
"stoploss": -0.322
}
},
"ft_stratparam_v": 1,
"export_time": "2025-04-23 12:30:05.550433+00:00"
}

View File

@ -30,7 +30,7 @@ class FreqaiExampleStrategy(IStrategy):
# FreqAI 配置
freqai_info = {
"model": "CatboostClassifier", # 与config保持一致
"model": "XGBoostRegressor", # 与config保持一致
"feature_parameters": {
"include_timeframes": ["3m", "15m", "1h"], # 与config一致
"include_corr_pairlist": ["BTC/USDT", "SOL/USDT"], # 添加相关交易对

View File

@ -1,336 +0,0 @@
import logging
import numpy as np
from functools import reduce
import talib.abstract as ta
from pandas import DataFrame
from technical import qtpylib
from freqtrade.strategy import IStrategy, IntParameter, DecimalParameter
logger = logging.getLogger(__name__)
class FreqaiExampleStrategy(IStrategy):
# 移除硬编码的 minimal_roi 和 stoploss改为动态适配
minimal_roi = {} # 将在 populate_indicators 中动态生成
stoploss = 0.0 # 将在 populate_indicators 中动态设置
trailing_stop = True
process_only_new_candles = True
use_exit_signal = True
startup_candle_count: int = 40
can_short = False
# 参数定义FreqAI 动态适配 buy_rsi 和 sell_rsi禁用 Hyperopt 优化
buy_rsi = IntParameter(low=10, high=50, default=27, space="buy", optimize=False, load=True)
sell_rsi = IntParameter(low=50, high=90, default=59, space="sell", optimize=False, load=True)
# 为 Hyperopt 优化添加 ROI 和 stoploss 参数
roi_0 = DecimalParameter(low=0.01, high=0.2, default=0.038, space="roi", optimize=True, load=True)
roi_15 = DecimalParameter(low=0.005, high=0.1, default=0.027, space="roi", optimize=True, load=True)
roi_30 = DecimalParameter(low=0.001, high=0.05, default=0.009, space="roi", optimize=True, load=True)
stoploss_param = DecimalParameter(low=-0.35, high=-0.1, default=-0.182, space="stoploss", optimize=True, load=True)
# FreqAI 配置
freqai_info = {
"model": "CatboostClassifier", # 与config保持一致
"feature_parameters": {
"include_timeframes": ["3m", "15m", "1h"], # 与config一致
"include_corr_pairlist": ["BTC/USDT", "SOL/USDT"], # 添加相关交易对
"label_period_candles": 20, # 与config一致
"include_shifted_candles": 2, # 与config一致
},
"data_split_parameters": {
"test_size": 0.2,
"shuffle": True, # 启用shuffle
},
"model_training_parameters": {
"n_estimators": 100, # 减少树的数量
"learning_rate": 0.1, # 提高学习率
"max_depth": 6, # 限制树深度
"subsample": 0.8, # 添加子采样
"colsample_bytree": 0.8, # 添加特征采样
"objective": "reg:squarederror",
"eval_metric": "rmse",
"early_stopping_rounds": 20,
"verbose": 0,
},
"data_kitchen": {
"feature_parameters": {
"DI_threshold": 1.5, # 降低异常值过滤阈值
"use_DBSCAN_to_remove_outliers": False # 禁用DBSCAN
}
}
}
plot_config = {
"main_plot": {},
"subplots": {
"&-buy_rsi": {"&-buy_rsi": {"color": "green"}},
"&-sell_rsi": {"&-sell_rsi": {"color": "red"}},
"&-stoploss": {"&-stoploss": {"color": "purple"}},
"&-roi_0": {"&-roi_0": {"color": "orange"}},
"do_predict": {"do_predict": {"color": "brown"}},
},
}
def featcaure_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
# 保留关键的技术指标
dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
# 确保 MACD 列被正确计算并保留
try:
macd = ta.MACD(dataframe, fastperiod=12, slowperiod=26, signalperiod=9)
dataframe["macd"] = macd["macd"]
dataframe["macdsignal"] = macd["macdsignal"]
except Exception as e:
logger.error(f"计算 MACD 列时出错:{str(e)}")
dataframe["macd"] = np.nan
dataframe["macdsignal"] = np.nan
# 检查 MACD 列是否存在
if "macd" not in dataframe.columns or "macdsignal" not in dataframe.columns:
logger.error("MACD 或 MACD 信号列缺失,无法生成买入信号")
raise ValueError("DataFrame 缺少必要的 MACD 列")
# 确保 MACD 列存在
if "macd" not in dataframe.columns or "macdsignal" not in dataframe.columns:
logger.error("MACD 或 MACD 信号列缺失,无法生成买入信号")
raise ValueError("DataFrame 缺少必要的 MACD 列")
# 保留布林带相关特征
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
dataframe["bb_lowerband"] = bollinger["lower"]
dataframe["bb_middleband"] = bollinger["mid"]
dataframe["bb_upperband"] = bollinger["upper"]
# 保留成交量相关特征
dataframe["volume_ma"] = dataframe["volume"].rolling(window=20).mean()
# 数据清理
for col in dataframe.columns:
if dataframe[col].dtype in ["float64", "int64"]:
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
dataframe[col] = dataframe[col].ffill().fillna(0)
logger.info(f"特征工程完成,特征数量:{len(dataframe.columns)}")
return dataframe
def feature_engineering_expand_basic(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
dataframe["%-pct-change"] = dataframe["close"].pct_change()
dataframe["%-raw_volume"] = dataframe["volume"]
dataframe["%-raw_price"] = dataframe["close"]
# 数据清理逻辑
for col in dataframe.columns:
if dataframe[col].dtype in ["float64", "int64"]:
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0)
dataframe[col] = dataframe[col].ffill()
dataframe[col] = dataframe[col].fillna(0)
# 检查是否仍有无效值
if dataframe[col].isna().any() or np.isinf(dataframe[col]).any():
logger.warning(f"{col} 仍包含无效值,已填充为默认值")
dataframe[col] = dataframe[col].fillna(0)
return dataframe
def feature_engineering_standard(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
dataframe["%-day_of_week"] = dataframe["date"].dt.dayofweek
dataframe["%-hour_of_day"] = dataframe["date"].dt.hour
dataframe.replace([np.inf, -np.inf], 0, inplace=True)
dataframe.ffill(inplace=True)
dataframe.fillna(0, inplace=True)
return dataframe
def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
logger.info(f"设置 FreqAI 目标,交易对:{metadata['pair']}")
if "close" not in dataframe.columns:
logger.error("数据框缺少必要的 'close'")
raise ValueError("数据框缺少必要的 'close'")
try:
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
# 定义目标变量为未来价格变化百分比(连续值)
dataframe["up_or_down"] = (
dataframe["close"].shift(-label_period) - dataframe["close"]
) / dataframe["close"]
# 数据清理:处理 NaN 和 Inf 值
dataframe["up_or_down"] = dataframe["up_or_down"].replace([np.inf, -np.inf], np.nan)
dataframe["up_or_down"] = dataframe["up_or_down"].ffill().fillna(0)
# 确保目标变量是二维数组
if dataframe["up_or_down"].ndim == 1:
dataframe["up_or_down"] = dataframe["up_or_down"].values.reshape(-1, 1)
# 检查并处理 NaN 或无限值
dataframe["up_or_down"] = dataframe["up_or_down"].replace([np.inf, -np.inf], np.nan)
dataframe["up_or_down"] = dataframe["up_or_down"].ffill().fillna(0)
# 生成 %-volatility 特征
dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
# 确保 &-buy_rsi 列的值计算正确
dataframe["&-buy_rsi"] = ta.RSI(dataframe, timeperiod=14)
# 数据清理
for col in ["&-buy_rsi", "up_or_down", "%-volatility"]:
# 使用直接操作避免链式赋值
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
dataframe[col] = dataframe[col].ffill() # 替代 fillna(method='ffill')
dataframe[col] = dataframe[col].fillna(dataframe[col].mean()) # 使用均值填充 NaN 值
if dataframe[col].isna().any():
logger.warning(f"目标列 {col} 仍包含 NaN填充为默认值")
except Exception as e:
logger.error(f"创建 FreqAI 目标失败:{str(e)}")
raise
# Log the shape of the target variable for debugging
logger.info(f"目标列形状:{dataframe['up_or_down'].shape}")
logger.info(f"目标列预览:\n{dataframe[['up_or_down', '&-buy_rsi']].head().to_string()}")
return dataframe
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
logger.info(f"处理交易对:{metadata['pair']}")
dataframe = self.freqai.start(dataframe, metadata, self)
# 计算传统指标
dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
dataframe["bb_lowerband"] = bollinger["lower"]
dataframe["bb_middleband"] = bollinger["mid"]
dataframe["bb_upperband"] = bollinger["upper"]
dataframe["tema"] = ta.TEMA(dataframe, timeperiod=9)
# 生成 up_or_down 信号(非 FreqAI 目标)
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
# 使用未来价格变化方向生成 up_or_down 信号
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
dataframe["up_or_down"] = np.where(
dataframe["close"].shift(-label_period) > dataframe["close"], 1, 0
)
# 动态设置参数
if "&-buy_rsi" in dataframe.columns:
# 派生其他目标
dataframe["&-sell_rsi"] = dataframe["&-buy_rsi"] + 30
dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
# Ensure proper calculation and handle potential NaN values
dataframe["&-stoploss"] = (-0.1 - (dataframe["%-volatility"] * 10).clip(0, 0.25)).fillna(-0.1)
dataframe["&-roi_0"] = ((dataframe["close"] / dataframe["close"].shift(label_period) - 1).clip(0, 0.2)).fillna(0)
# Additional check to ensure no NaN values remain
for col in ["&-stoploss", "&-roi_0"]:
if dataframe[col].isna().any():
logger.warning(f"{col} 仍包含 NaN填充为默认值")
dataframe[col] = dataframe[col].fillna(-0.1 if col == "&-stoploss" else 0)
# 简化动态参数生成逻辑
# 放松 buy_rsi 和 sell_rsi 的生成逻辑
# 计算 buy_rsi_pred 并清理 NaN 值
dataframe["buy_rsi_pred"] = dataframe["rsi"].rolling(window=10).mean().clip(30, 50)
dataframe["buy_rsi_pred"] = dataframe["buy_rsi_pred"].fillna(dataframe["buy_rsi_pred"].median())
# 计算 sell_rsi_pred 并清理 NaN 值
dataframe["sell_rsi_pred"] = dataframe["buy_rsi_pred"] + 20
dataframe["sell_rsi_pred"] = dataframe["sell_rsi_pred"].fillna(dataframe["sell_rsi_pred"].median())
# 计算 stoploss_pred 并清理 NaN 值
dataframe["stoploss_pred"] = -0.1 - (dataframe["%-volatility"] * 10).clip(0, 0.25)
dataframe["stoploss_pred"] = dataframe["stoploss_pred"].fillna(dataframe["stoploss_pred"].mean())
# 计算 roi_0_pred 并清理 NaN 值
dataframe["roi_0_pred"] = dataframe["&-roi_0"].clip(0.01, 0.2)
dataframe["roi_0_pred"] = dataframe["roi_0_pred"].fillna(dataframe["roi_0_pred"].mean())
# 检查预测值
for col in ["buy_rsi_pred", "sell_rsi_pred", "stoploss_pred", "roi_0_pred", "&-sell_rsi", "&-stoploss", "&-roi_0"]:
if dataframe[col].isna().any():
logger.warning(f"{col} 包含 NaN填充为默认值")
dataframe[col] = dataframe[col].fillna(dataframe[col].mean())
# 更保守的止损和止盈设置
dataframe["trailing_stop_positive"] = (dataframe["roi_0_pred"] * 0.3).clip(0.01, 0.2)
dataframe["trailing_stop_positive_offset"] = (dataframe["roi_0_pred"] * 0.5).clip(0.01, 0.3)
# 设置策略级参数
self.buy_rsi.value = float(dataframe["buy_rsi_pred"].iloc[-1])
self.sell_rsi.value = float(dataframe["sell_rsi_pred"].iloc[-1])
# 更保守的止损设置
self.stoploss = -0.15 # 固定止损 15%
self.minimal_roi = {
0: float(self.roi_0.value),
15: float(self.roi_15.value),
30: float(self.roi_30.value),
60: 0
}
# 更保守的追踪止损设置
self.trailing_stop_positive = 0.05 # 追踪止损触发点
self.trailing_stop_positive_offset = 0.1 # 追踪止损偏移量
logger.info(f"动态参数buy_rsi={self.buy_rsi.value}, sell_rsi={self.sell_rsi.value}, "
f"stoploss={self.stoploss}, trailing_stop_positive={self.trailing_stop_positive}")
dataframe.replace([np.inf, -np.inf], 0, inplace=True)
dataframe.ffill(inplace=True)
dataframe.fillna(0, inplace=True)
logger.info(f"up_or_down 值统计:\n{dataframe['up_or_down'].value_counts().to_string()}")
logger.info(f"do_predict 值统计:\n{dataframe['do_predict'].value_counts().to_string()}")
return dataframe
def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
# 改进卖出信号条件
exit_long_conditions = [
(df["rsi"] > df["sell_rsi_pred"]), # RSI 高于卖出阈值
(df["volume"] > df["volume"].rolling(window=10).mean()), # 成交量高于近期均值
(df["close"] < df["bb_middleband"]) # 价格低于布林带中轨
]
if exit_long_conditions:
df.loc[
reduce(lambda x, y: x & y, exit_long_conditions),
"exit_long"
] = 1
return df
def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
# 改进买入信号条件
# 检查 MACD 列是否存在
if "macd" not in df.columns or "macdsignal" not in df.columns:
logger.error("MACD 或 MACD 信号列缺失,无法生成买入信号。尝试重新计算 MACD 列。")
try:
macd = ta.MACD(df, fastperiod=12, slowperiod=26, signalperiod=9)
df["macd"] = macd["macd"]
df["macdsignal"] = macd["macdsignal"]
logger.info("MACD 列已成功重新计算。")
except Exception as e:
logger.error(f"重新计算 MACD 列时出错:{str(e)}")
raise ValueError("DataFrame 缺少必要的 MACD 列且无法重新计算。")
enter_long_conditions = [
(df["rsi"] < df["buy_rsi_pred"]), # RSI 低于买入阈值
(df["volume"] > df["volume"].rolling(window=10).mean() * 1.2), # 成交量高于近期均值20%
(df["close"] > df["bb_middleband"]) # 价格高于布林带中轨
]
# 如果 MACD 列存在,则添加 MACD 金叉条件
if "macd" in df.columns and "macdsignal" in df.columns:
enter_long_conditions.append((df["macd"] > df["macdsignal"]))
# 确保模型预测为买入
enter_long_conditions.append((df["do_predict"] == 1))
if enter_long_conditions:
df.loc[
reduce(lambda x, y: x & y, enter_long_conditions),
["enter_long", "enter_tag"]
] = (1, "long")
return df
def confirm_trade_entry(
self, pair: str, order_type: str, amount: float, rate: float,
time_in_force: str, current_time, entry_tag, side: str, **kwargs
) -> bool:
df, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
last_candle = df.iloc[-1].squeeze()
if side == "long":
if rate > (last_candle["close"] * (1 + 0.0025)):
return False
return True