主要修改点:
1. 修复了方法名拼写错误(featcaure -> feature) 2. 增加了基础数据检查 3. 添加了更多技术指标(动量、ROC、波动率等) 4. 增加了特征数量检查 5. 添加了更详细的日志记录
This commit is contained in:
parent
5df9f23baf
commit
5027018fe0
@ -71,31 +71,27 @@ class FreqaiExampleStrategy(IStrategy):
|
||||
},
|
||||
}
|
||||
|
||||
def featcaure_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
|
||||
# 计算关键指标
|
||||
def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
|
||||
# 确保基础价格数据存在
|
||||
if "close" not in dataframe.columns:
|
||||
raise ValueError("Dataframe must contain 'close' column")
|
||||
|
||||
# 计算技术指标
|
||||
dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
|
||||
|
||||
# 计算短期和长期SMA
|
||||
dataframe["sma_short"] = ta.SMA(dataframe, timeperiod=12)
|
||||
dataframe["sma_long"] = ta.SMA(dataframe, timeperiod=26)
|
||||
dataframe["sma_cross"] = np.where(dataframe["sma_short"] > dataframe["sma_long"], 1, -1)
|
||||
|
||||
# 计算SMA交叉信号
|
||||
dataframe["sma_cross"] = np.where(
|
||||
dataframe["sma_short"] > dataframe["sma_long"], 1, -1
|
||||
)
|
||||
|
||||
# 计算布林带
|
||||
# 布林带
|
||||
bollinger = qtpylib.bollinger_bands(qtpylib.typical_price(dataframe), window=20, stds=2)
|
||||
dataframe["bb_lowerband"] = bollinger["lower"]
|
||||
dataframe["bb_middleband"] = bollinger["mid"]
|
||||
dataframe["bb_upperband"] = bollinger["upper"]
|
||||
|
||||
# 计算价格与布林带的关系
|
||||
dataframe["bb_pct"] = (dataframe["close"] - dataframe["bb_lowerband"]) / (
|
||||
dataframe["bb_upperband"] - dataframe["bb_lowerband"]
|
||||
)
|
||||
|
||||
# 成交量相关特征
|
||||
# 成交量特征
|
||||
dataframe["volume_ma"] = dataframe["volume"].rolling(window=20).mean()
|
||||
dataframe["volume_pct"] = dataframe["volume"] / dataframe["volume_ma"]
|
||||
|
||||
@ -104,13 +100,25 @@ class FreqaiExampleStrategy(IStrategy):
|
||||
dataframe["pct_change_5"] = dataframe["close"].pct_change(5)
|
||||
dataframe["pct_change_10"] = dataframe["close"].pct_change(10)
|
||||
|
||||
# 动量指标
|
||||
dataframe["momentum"] = dataframe["close"] / dataframe["close"].shift(4) - 1
|
||||
dataframe["roc"] = ta.ROC(dataframe, timeperiod=10)
|
||||
|
||||
# 波动率
|
||||
dataframe["volatility"] = dataframe["close"].pct_change().rolling(window=20).std()
|
||||
|
||||
# 数据清理
|
||||
for col in dataframe.columns:
|
||||
if dataframe[col].dtype in ["float64", "int64"]:
|
||||
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], np.nan)
|
||||
dataframe[col] = dataframe[col].ffill().fillna(0)
|
||||
|
||||
logger.info(f"特征工程完成,特征数量:{len(dataframe.columns)}")
|
||||
# 确保至少有一个特征
|
||||
if len(dataframe.columns) == 0:
|
||||
raise ValueError("No features generated in feature engineering")
|
||||
|
||||
logger.info(f"特征工程完成,生成特征数量:{len(dataframe.columns)}")
|
||||
logger.debug(f"特征列表:{list(dataframe.columns)}")
|
||||
return dataframe
|
||||
|
||||
def feature_engineering_expand_basic(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user