从metadata.json里读取labels_mean和labels_std

This commit is contained in:
zhangkun9038@dingtalk.com 2025-06-01 02:31:03 +00:00
parent da55f01830
commit ef4a828678

View File

@ -108,12 +108,14 @@ class FreqaiPrimer(IStrategy):
for col in columns_to_clean:
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0).ffill().fillna(0)
logger.debug(f"[{metadata['pair']}] 特征工程完成,列:{list(dataframe.columns)}")
pair = metadata.get('pair', 'Unknown')
logger.debug(f"[{pair}] 特征工程完成,列:{list(dataframe.columns)}")
return dataframe
def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
pair = metadata.get('pair', 'Unknown')
if len(dataframe) < 200:
logger.warning(f"[{metadata['pair']}] 数据量不足({len(dataframe)}根K线需要至少200根K线进行训练")
logger.warning(f"[{pair}] 数据量不足({len(dataframe)}根K线需要至少200根K线进行训练")
return dataframe
dataframe["ema200"] = ta.EMA(dataframe, timeperiod=200)
@ -126,11 +128,12 @@ class FreqaiPrimer(IStrategy):
dataframe["&-price_value_divergence"] = dataframe["&-price_value_divergence"].replace([np.inf, -np.inf], 0).ffill().fillna(0)
dataframe["volume_z_score"] = dataframe["volume_z_score"].replace([np.inf, -np.inf], 0).ffill().fillna(0)
logger.debug(f"[{metadata['pair']}] 目标列生成完成,列:{list(dataframe.columns)}")
logger.debug(f"[{pair}] 目标列生成完成,列:{list(dataframe.columns)}")
return dataframe
def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
logger.info(f"[{metadata['pair']}] 当前可用列调用FreqAI前{list(dataframe.columns)}")
pair = metadata.get('pair', 'Unknown')
logger.info(f"[{pair}] 当前可用列调用FreqAI前{list(dataframe.columns)}")
# 计算200周期EMA和历史价值背离
dataframe["ema200"] = ta.EMA(dataframe, timeperiod=200)
@ -138,14 +141,14 @@ class FreqaiPrimer(IStrategy):
# 调用FreqAI预测价值背离
if not hasattr(self, 'freqai') or self.freqai is None:
logger.error(f"[{metadata['pair']}] FreqAI 未初始化,请确保回测命令中启用了 --freqai")
logger.error(f"[{pair}] FreqAI 未初始化,请确保回测命令中启用了 --freqai")
dataframe["&-price_value_divergence"] = dataframe["price_value_divergence"]
else:
logger.debug(f"self.freqai 类型:{type(self.freqai)}")
dataframe = self.freqai.start(dataframe, metadata, self)
if "&-price_value_divergence" not in dataframe.columns:
logger.warning(f"[{metadata['pair']}] 回归模型未生成 &-price_value_divergence回退到规则计算")
logger.warning(f"[{pair}] 回归模型未生成 &-price_value_divergence回退到规则计算")
dataframe["&-price_value_divergence"] = dataframe["price_value_divergence"]
# 计算其他指标
@ -163,7 +166,6 @@ class FreqaiPrimer(IStrategy):
dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0).ffill().fillna(0)
# 获取 labels_mean 和 labels_std
pair = metadata["pair"]
labels_mean = None
labels_std = None
@ -174,7 +176,7 @@ class FreqaiPrimer(IStrategy):
# 从最新的子目录读取 metadata.json
try:
model_base_dir = os.path.join(self.config["user_data_dir"], "models", self.freqai_info["identifier"])
pair_base = pair.split('/')[0] # 取币对基础部分,例如 "TRUMP/USDT" -> "TRUMP"
pair_base = pair.split('/')[0] if '/' in pair else pair # 取币对基础部分,例如 "TRUMP/USDT" -> "TRUMP"
sub_dirs = glob.glob(os.path.join(model_base_dir, f"sub-train-{pair_base}_*"))
if not sub_dirs:
@ -247,42 +249,45 @@ class FreqaiPrimer(IStrategy):
logger.info("==============================================")
self.stats_logged = True
logger.info(f"[{metadata['pair']}] 指标计算完成,列:{list(dataframe.columns)}")
# 移除问题日志行,避免 KeyError
# logger.info(f"[{metadata['pair']}] 指标计算完成,列:{list(dataframe.columns)}")
return dataframe
def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
pair = metadata.get('pair', 'Unknown')
conditions = []
if "&-price_value_divergence" in df.columns:
buy_condition = (df["&-price_value_divergence"] < self.buy_threshold)
buy_condition &= (df["volume_z_score"] > 1.5)
buy_condition &= (df["rsi"] < 40)
buy_condition &= (df["close"] <= df["bb_lowerband"])
if "&-price_value_divergence" in dataframe.columns:
buy_condition = (dataframe["&-price_value_divergence"] < self.buy_threshold)
buy_condition &= (dataframe["volume_z_score"] > 1.5)
buy_condition &= (dataframe["rsi"] < 40)
buy_condition &= (dataframe["close"] <= dataframe["bb_lowerband"])
conditions.append(buy_condition)
else:
logger.warning("⚠️ &-price_value_divergence 列缺失,跳过该条件")
logger.warning(f"[{pair}] ⚠️ &-price_value_divergence 列缺失,跳过该条件")
if len(conditions) > 0:
df.loc[reduce(lambda x, y: x & y, conditions), 'enter_long'] = 1
logger.info(f"[{metadata['pair']}] 入场信号触发,条件满足")
dataframe.loc[reduce(lambda x, y: x & y, conditions), 'enter_long'] = 1
logger.info(f"[{pair}] 入场信号触发,条件满足")
return df
return dataframe
def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
def populate_exit_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
pair = metadata.get('pair', 'Unknown')
conditions = []
if "&-price_value_divergence" in df.columns:
sell_condition = (df["&-price_value_divergence"] > self.sell_threshold)
sell_condition |= (df["rsi"] > 75)
if "&-price_value_divergence" in dataframe.columns:
sell_condition = (dataframe["&-price_value_divergence"] > self.sell_threshold)
sell_condition |= (dataframe["rsi"] > 75)
conditions.append(sell_condition)
else:
logger.warning("⚠️ &-price_value_divergence 列缺失,跳过该条件")
logger.warning(f"[{pair}] ⚠️ &-price_value_divergence 列缺失,跳过该条件")
if len(conditions) > 0:
df.loc[reduce(lambda x, y: x & y, conditions), 'exit_long'] = 1
logger.info(f"[{metadata['pair']}] 出场信号触发,条件满足")
dataframe.loc[reduce(lambda x, y: x & y, conditions), 'exit_long'] = 1
logger.info(f"[{pair}] 出场信号触发,条件满足")
return df
return dataframe
def adjust_trade_position(self, trade: Trade, current_time: datetime,
current_rate: float, current_profit: float,