sma
This commit is contained in:
parent
fdc1fa0374
commit
5df9f23baf
@ -72,32 +72,38 @@ class FreqaiExampleStrategy(IStrategy):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def featcaure_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
|
def featcaure_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
|
||||||
# 保留关键的技术指标
|
# 计算关键指标
|
||||||
dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
|
dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
|
||||||
|
|
||||||
# 计算短期和长期 SMA
|
# 计算短期和长期SMA
|
||||||
dataframe["sma_short"] = ta.SMA(dataframe, timeperiod=12)
|
dataframe["sma_short"] = ta.SMA(dataframe, timeperiod=12)
|
||||||
dataframe["sma_long"] = ta.SMA(dataframe, timeperiod=26)
|
dataframe["sma_long"] = ta.SMA(dataframe, timeperiod=26)
|
||||||
|
|
||||||
# 检查 SMA 列是否存在
|
# 计算SMA交叉信号
|
||||||
if "sma_short" not in dataframe.columns or "sma_long" not in dataframe.columns:
|
dataframe["sma_cross"] = np.where(
|
||||||
logger.error("SMA 列缺失,无法生成买入信号")
|
dataframe["sma_short"] > dataframe["sma_long"], 1, -1
|
||||||
raise ValueError("DataFrame 缺少必要的 SMA 列")
|
)
|
||||||
|
|
||||||
# 确保 MACD 列存在
|
# 计算布林带
|
||||||
if "macd" not in dataframe.columns or "macdsignal" not in dataframe.columns:
|
|
||||||
logger.error("MACD 或 MACD 信号列缺失,无法生成买入信号")
|
|
||||||
raise ValueError("DataFrame 缺少必要的 MACD 列")
|
|
||||||
|
|
||||||
# 保留布林带相关特征
|
|
||||||
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
|
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
|
||||||
dataframe["bb_lowerband"] = bollinger["lower"]
|
dataframe["bb_lowerband"] = bollinger["lower"]
|
||||||
dataframe["bb_middleband"] = bollinger["mid"]
|
dataframe["bb_middleband"] = bollinger["mid"]
|
||||||
dataframe["bb_upperband"] = bollinger["upper"]
|
dataframe["bb_upperband"] = bollinger["upper"]
|
||||||
|
|
||||||
# 保留成交量相关特征
|
# 计算价格与布林带的关系
|
||||||
|
dataframe["bb_pct"] = (dataframe["close"] - dataframe["bb_lowerband"]) / (
|
||||||
|
dataframe["bb_upperband"] - dataframe["bb_lowerband"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 成交量相关特征
|
||||||
dataframe["volume_ma"] = dataframe["volume"].rolling(window=20).mean()
|
dataframe["volume_ma"] = dataframe["volume"].rolling(window=20).mean()
|
||||||
|
dataframe["volume_pct"] = dataframe["volume"] / dataframe["volume_ma"]
|
||||||
|
|
||||||
|
# 价格变化特征
|
||||||
|
dataframe["pct_change"] = dataframe["close"].pct_change()
|
||||||
|
dataframe["pct_change_5"] = dataframe["close"].pct_change(5)
|
||||||
|
dataframe["pct_change_10"] = dataframe["close"].pct_change(10)
|
||||||
|
|
||||||
# 数据清理
|
# 数据清理
|
||||||
for col in dataframe.columns:
|
for col in dataframe.columns:
|
||||||
if dataframe[col].dtype in ["float64", "int64"]:
|
if dataframe[col].dtype in ["float64", "int64"]:
|
||||||
@ -142,34 +148,36 @@ class FreqaiExampleStrategy(IStrategy):
|
|||||||
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
|
label_period = self.freqai_info["feature_parameters"]["label_period_candles"]
|
||||||
|
|
||||||
# 定义目标变量为未来价格变化百分比(连续值)
|
# 定义目标变量为未来价格变化百分比(连续值)
|
||||||
dataframe["up_or_down"] = (
|
dataframe["target"] = (
|
||||||
dataframe["close"].shift(-label_period) - dataframe["close"]
|
dataframe["close"].shift(-label_period) - dataframe["close"]
|
||||||
) / dataframe["close"]
|
) / dataframe["close"]
|
||||||
|
|
||||||
|
# 添加辅助目标变量
|
||||||
|
dataframe["target_5"] = (
|
||||||
|
dataframe["close"].shift(-5) - dataframe["close"]
|
||||||
|
) / dataframe["close"]
|
||||||
|
|
||||||
|
dataframe["target_10"] = (
|
||||||
|
dataframe["close"].shift(-10) - dataframe["close"]
|
||||||
|
) / dataframe["close"]
|
||||||
|
|
||||||
# 数据清理:处理 NaN 和 Inf 值
|
# 数据清理:处理 NaN 和 Inf 值
|
||||||
dataframe["up_or_down"] = dataframe["up_or_down"].replace([np.inf, -np.inf], np.nan)
|
for col in ["target", "target_5", "target_10"]:
|
||||||
dataframe["up_or_down"] = dataframe["up_or_down"].ffill().fillna(0)
|
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
|
||||||
|
dataframe[col] = dataframe[col].ffill().fillna(0)
|
||||||
|
|
||||||
# 确保目标变量是二维数组
|
# 确保目标变量是二维数组
|
||||||
if dataframe["up_or_down"].ndim == 1:
|
if dataframe["target"].ndim == 1:
|
||||||
dataframe["up_or_down"] = dataframe["up_or_down"].values.reshape(-1, 1)
|
dataframe["target"] = dataframe["target"].values.reshape(-1, 1)
|
||||||
|
|
||||||
# 检查并处理 NaN 或无限值
|
# 生成波动率特征
|
||||||
dataframe["up_or_down"] = dataframe["up_or_down"].replace([np.inf, -np.inf], np.nan)
|
|
||||||
dataframe["up_or_down"] = dataframe["up_or_down"].ffill().fillna(0)
|
|
||||||
|
|
||||||
# 生成 %-volatility 特征
|
|
||||||
dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
|
dataframe["%-volatility"] = dataframe["close"].pct_change().rolling(20).std()
|
||||||
|
|
||||||
# 确保 &-buy_rsi 列的值计算正确
|
|
||||||
dataframe["&-buy_rsi"] = ta.RSI(dataframe, timeperiod=14)
|
|
||||||
|
|
||||||
# 数据清理
|
# 数据清理
|
||||||
for col in ["&-buy_rsi", "up_or_down", "%-volatility"]:
|
for col in ["target", "target_5", "target_10", "%-volatility"]:
|
||||||
# 使用直接操作避免链式赋值
|
|
||||||
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
|
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
|
||||||
dataframe[col] = dataframe[col].ffill() # 替代 fillna(method='ffill')
|
dataframe[col] = dataframe[col].ffill()
|
||||||
dataframe[col] = dataframe[col].fillna(dataframe[col].mean()) # 使用均值填充 NaN 值
|
dataframe[col] = dataframe[col].fillna(dataframe[col].mean())
|
||||||
if dataframe[col].isna().any():
|
if dataframe[col].isna().any():
|
||||||
logger.warning(f"目标列 {col} 仍包含 NaN,填充为默认值")
|
logger.warning(f"目标列 {col} 仍包含 NaN,填充为默认值")
|
||||||
|
|
||||||
@ -178,8 +186,8 @@ class FreqaiExampleStrategy(IStrategy):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# Log the shape of the target variable for debugging
|
# Log the shape of the target variable for debugging
|
||||||
logger.info(f"目标列形状:{dataframe['up_or_down'].shape}")
|
logger.info(f"目标列形状:{dataframe['target'].shape}")
|
||||||
logger.info(f"目标列预览:\n{dataframe[['up_or_down', '&-buy_rsi']].head().to_string()}")
|
logger.info(f"目标列预览:\n{dataframe[['target', 'target_5', 'target_10']].head().to_string()}")
|
||||||
return dataframe
|
return dataframe
|
||||||
|
|
||||||
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
|
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user