myTestFreqAI/freqtrade/templates/freqaiprimer.py
2025-06-01 03:32:52 +00:00

366 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import glob
import json
import logging
import os
from functools import reduce
from typing import Dict, Optional

import numpy as np
import talib.abstract as ta
from pandas import DataFrame

from freqtrade.persistence import Trade
from freqtrade.strategy import IStrategy
logger = logging.getLogger(__name__)


class FreqaiPrimer(IStrategy):
    """FreqAI-assisted strategy trading the close-vs-EMA200 divergence.

    A FreqAI regressor (configured in ``freqai_info``) supplies
    ``&-price_value_divergence``. Entries fire when it drops below a per-pair
    dynamic buy threshold while volume z-score, RSI and the lower Bollinger band
    confirm. Exits fire when it rises above a per-pair sell threshold or RSI > 75.
    """

    minimal_roi = {
        "0": 0.02,
        "30": 0.01,
        "60": 0,
    }
    stoploss = -0.015
    timeframe = "3m"
    use_custom_stoploss = False

    plot_config = {
        "main_plot": {
            "ema200": {"color": "blue"},
            "bb_upperband": {"color": "gray"},
            "bb_lowerband": {"color": "gray"},
            "bb_middleband": {"color": "gray"},
        },
        "subplots": {
            "Signals": {
                "enter_long": {"color": "green"},
                "exit_long": {"color": "red"},
            },
            "Price-Value Divergence": {
                "&-price_value_divergence": {"color": "purple"},
            },
            "Volume Z-Score": {
                "volume_z_score": {"color": "orange"},
            },
            "RSI": {
                "rsi": {"color": "cyan"},
            },
        },
    }

    freqai_info = {
        "identifier": "test58",
        "model": "LightGBMRegressor",
        "feature_parameters": {
            "include_timeframes": ["3m", "15m", "1h"],
            "label_period_candles": 12,
            "include_shifted_candles": 3,
        },
        "data_split_parameters": {
            "test_size": 0.2,
            "shuffle": False,
        },
        "model_training_parameters": {
            "n_estimators": 200,
            "learning_rate": 0.05,
            "num_leaves": 31,
            "verbose": -1,
        },
        "fit_live_predictions_candles": 100,
        "live_retrain_candles": 100,
    }

    def __init__(self, config: dict, *args, **kwargs):
        """Initialise strategy state on top of ``IStrategy``."""
        super().__init__(config, *args, **kwargs)
        logger.setLevel(logging.DEBUG)  # Keep DEBUG level to see more logs.
        logger.debug("✅ 策略已初始化,日志级别设置为 DEBUG")
        # Legacy flag: trailing-stop activation is now derived per trade from
        # trade.max_rate in adjust_trade_position; kept for existing readers.
        self.trailing_stop_enabled = False
        self.trailing_stop_start = 0.03
        self.trailing_stop_distance = 0.01
        # Per-pair label statistics and dynamic thresholds, keyed by pair name.
        self.pair_stats = {}
        self.stats_logged = False
        self.fit_live_predictions_candles = self.freqai_info.get("fit_live_predictions_candles", 100)

    def feature_engineering_expand_all(self, dataframe: DataFrame, period: int, metadata: dict, **kwargs) -> DataFrame:
        """Add period-expanded FreqAI features (``%-`` prefixed columns) to *dataframe*."""
        dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period)
        dataframe["%-sma-period"] = ta.SMA(dataframe, timeperiod=period)
        dataframe["%-ema-period"] = ta.EMA(dataframe, timeperiod=period)
        real = ta.TYPPRICE(dataframe)
        upperband, middleband, lowerband = ta.BBANDS(real, timeperiod=period, nbdevup=2.0, nbdevdn=2.0)
        dataframe["bb_lowerband-period"] = lowerband
        dataframe["bb_upperband-period"] = upperband
        dataframe["bb_middleband-period"] = middleband
        dataframe["%-bb_width-period"] = (
            dataframe["bb_upperband-period"] - dataframe["bb_lowerband-period"]
        ) / dataframe["bb_middleband-period"]
        dataframe["%-mfi-period"] = ta.MFI(dataframe, timeperiod=period)
        dataframe["%-adx-period"] = ta.ADX(dataframe, timeperiod=period)
        dataframe["%-relative_volume-period"] = dataframe["volume"] / dataframe["volume"].rolling(period).mean()
        dataframe["ema200"] = ta.EMA(dataframe, timeperiod=200)
        dataframe["%-price_value_divergence"] = (dataframe["close"] - dataframe["ema200"]) / dataframe["ema200"]
        columns_to_clean = [
            "%-rsi-period", "%-mfi-period", "%-sma-period", "%-ema-period", "%-adx-period",
            "bb_lowerband-period", "bb_middleband-period", "bb_upperband-period",
            "%-bb_width-period", "%-relative_volume-period", "%-price_value_divergence",
        ]
        # Replace +/-inf (e.g. division by zero) with 0, forward-fill, then zero-fill leading NaNs.
        for col in columns_to_clean:
            dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0).ffill().fillna(0)
        pair = metadata.get('pair', 'Unknown')
        logger.debug(f"[{pair}] 特征工程完成,列:{list(dataframe.columns)}")
        return dataframe

    def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs) -> DataFrame:
        """Add the FreqAI target ``&-price_value_divergence`` plus volume z-score columns.

        Returns *dataframe* unchanged (with a warning) when it has fewer than 200 rows.
        """
        pair = metadata.get('pair', 'Unknown')
        if len(dataframe) < 200:
            logger.warning(f"[{pair}] 数据量不足({len(dataframe)}根K线需要至少200根K线进行训练")
            return dataframe
        dataframe["ema200"] = ta.EMA(dataframe, timeperiod=200)
        # NOTE(review): the target is computed from the same candle and is not
        # shifted forward (no .shift(-label_period_candles)), so the model learns
        # the current divergence rather than a future one. Confirm this is intended.
        dataframe["&-price_value_divergence"] = (dataframe["close"] - dataframe["ema200"]) / dataframe["ema200"]
        dataframe["volume_mean_20"] = dataframe["volume"].rolling(20).mean()
        dataframe["volume_std_20"] = dataframe["volume"].rolling(20).std()
        dataframe["volume_z_score"] = (dataframe["volume"] - dataframe["volume_mean_20"]) / dataframe["volume_std_20"]
        dataframe["&-price_value_divergence"] = dataframe["&-price_value_divergence"].replace([np.inf, -np.inf], 0).ffill().fillna(0)
        dataframe["volume_z_score"] = dataframe["volume_z_score"].replace([np.inf, -np.inf], 0).ffill().fillna(0)
        logger.debug(f"[{pair}] 目标列生成完成,列:{list(dataframe.columns)}")
        return dataframe

    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Compute indicators, run FreqAI, and derive this pair's dynamic thresholds.

        Side effects: stores ``labels_mean``, ``labels_std``, ``buy_threshold`` and
        ``sell_threshold`` for the pair in ``self.pair_stats``. The thresholds are
        also mirrored on ``self.buy_threshold``/``self.sell_threshold``; those hold
        the last pair processed and are kept for backward compatibility only. The
        entry/exit methods read the per-pair values, because backtesting runs this
        method for every pair before any entry/exit pass.
        """
        pair = metadata.get('pair', 'Unknown')
        logger.info(f"[{pair}] 当前可用列调用FreqAI前{list(dataframe.columns)}")
        # 200-period EMA and the rule-based (historical) value divergence.
        dataframe["ema200"] = ta.EMA(dataframe, timeperiod=200)
        dataframe["price_value_divergence"] = (dataframe["close"] - dataframe["ema200"]) / dataframe["ema200"]

        # Ask FreqAI for the predicted divergence; fall back to the rule-based value.
        if not hasattr(self, 'freqai') or self.freqai is None:
            logger.error(f"[{pair}] FreqAI 未初始化,请确保回测命令中启用了 --freqai")
            dataframe["&-price_value_divergence"] = dataframe["price_value_divergence"]
        else:
            logger.debug(f"self.freqai 类型:{type(self.freqai)}")
            dataframe = self.freqai.start(dataframe, metadata, self)
            if "&-price_value_divergence" not in dataframe.columns:
                logger.warning(f"[{pair}] 回归模型未生成 &-price_value_divergence回退到规则计算")
                dataframe["&-price_value_divergence"] = dataframe["price_value_divergence"]

        # Indicators used by the entry/exit rules.
        upperband, middleband, lowerband = ta.BBANDS(dataframe["close"], timeperiod=20, nbdevup=2.0, nbdevdn=2.0)
        dataframe["bb_upperband"] = upperband
        dataframe["bb_middleband"] = middleband
        dataframe["bb_lowerband"] = lowerband
        dataframe["rsi"] = ta.RSI(dataframe, timeperiod=14)
        dataframe["volume_mean_20"] = dataframe["volume"].rolling(20).mean()
        dataframe["volume_std_20"] = dataframe["volume"].rolling(20).std()
        dataframe["volume_z_score"] = (dataframe["volume"] - dataframe["volume_mean_20"]) / dataframe["volume_std_20"]

        # Data cleaning: +/-inf -> 0, forward-fill, then zero-fill leading NaNs.
        for col in ["ema200", "bb_upperband", "bb_middleband", "bb_lowerband", "rsi", "volume_z_score",
                    "&-price_value_divergence", "price_value_divergence"]:
            dataframe[col] = dataframe[col].replace([np.inf, -np.inf], 0).ffill().fillna(0)

        # Debug log of the latest key indicator values (skipped for an empty frame).
        if not dataframe.empty:
            logger.debug(f"[{pair}] 最新数据 - close{dataframe['close'].iloc[-1]:.6f}, "
                         f"rsi{dataframe['rsi'].iloc[-1]:.2f}, "
                         f"&-price_value_divergence{dataframe['&-price_value_divergence'].iloc[-1]:.6f}, "
                         f"volume_z_score{dataframe['volume_z_score'].iloc[-1]:.2f}, "
                         f"bb_lowerband{dataframe['bb_lowerband'].iloc[-1]:.6f}")

        # Label statistics: prefer the trained model's metadata, else recompute.
        labels_mean, labels_std = self._read_label_stats_from_model(pair)
        if labels_mean is None or labels_std is None:
            logger.warning(f"[{pair}] 无法获取 labels_mean 和 labels_std重新计算")
            labels_mean, labels_std = self._recent_label_stats(dataframe)
        if labels_std is None or np.isnan(labels_std) or labels_std == 0:
            labels_std = 0.01
            logger.warning(f"[{pair}] labels_std 计算异常,使用默认值 0.01")
        if labels_mean is None or np.isnan(labels_mean):
            # An empty/all-NaN window would otherwise poison both thresholds.
            labels_mean = 0.0

        k_buy, k_sell, buy_threshold, sell_threshold = self._dynamic_thresholds(pair, labels_mean, labels_std)
        self.pair_stats[pair] = {
            "labels_mean": labels_mean,
            "labels_std": labels_std,
            "buy_threshold": buy_threshold,
            "sell_threshold": sell_threshold,
        }
        # Mirrors of the latest pair's thresholds, kept for backward compatibility.
        self.buy_threshold = buy_threshold
        self.sell_threshold = sell_threshold

        logger.info(f"[{pair}] labels_mean{labels_mean:.4f}, labels_std{labels_std:.4f}")
        logger.info(f"[{pair}] k_buy{k_buy:.2f}, k_sell{k_sell:.2f}")
        logger.info(f"[{pair}] 动态买入阈值:{self.buy_threshold:.4f}")
        logger.info(f"[{pair}] 动态卖出阈值:{self.sell_threshold:.4f}")
        if not self.stats_logged:
            logger.info("===== 所有币对的 labels_mean 和 labels_std 汇总 =====")
            for p, stats in self.pair_stats.items():
                logger.info(f"[{p}] labels_mean{stats['labels_mean']:.4f}, labels_std{stats['labels_std']:.4f}")
            logger.info("==============================================")
            self.stats_logged = True
        return dataframe

    def _read_label_stats_from_model(self, pair: str):
        """Return ``(labels_mean, labels_std)`` of the target from the newest FreqAI model metadata.

        Searches ``<user_data_dir>/models/<identifier>/sub-train-<BASE>_<timestamp>``
        for ``cb_<base>_<timestamp>_metadata.json``. Returns ``(None, None)`` when
        nothing usable is found. This is best-effort: failures are logged, not raised.
        """
        target = "&-price_value_divergence"
        try:
            logger.debug(f"freqai_info identifier{self.freqai_info['identifier']}")
            logger.debug(f"user_data_dir{self.config['user_data_dir']}")
            model_base_dir = os.path.join(self.config["user_data_dir"], "models", self.freqai_info["identifier"])
            pair_base = pair.split('/')[0]
            # Only directories whose suffix is a numeric timestamp are candidates.
            sub_dirs = [
                d for d in glob.glob(os.path.join(model_base_dir, f"sub-train-{pair_base}_*"))
                if d.rsplit('_', 1)[-1].isdigit()
            ]
            if not sub_dirs:
                logger.warning(f"[{pair}] 未找到任何子目录:{model_base_dir}/sub-train-{pair_base}_*")
                return None, None
            latest_sub_dir = max(sub_dirs, key=lambda d: int(d.rsplit('_', 1)[-1]))
            timestamp = latest_sub_dir.rsplit('_', 1)[-1]
            metadata_file = os.path.join(latest_sub_dir, f"cb_{pair_base.lower()}_{timestamp}_metadata.json")
            if not os.path.exists(metadata_file):
                logger.warning(f"[{pair}] 最新的 metadata.json 文件 {metadata_file} 不存在")
                return None, None
            with open(metadata_file, "r", encoding="utf-8") as f:
                model_meta = json.load(f)  # distinct name: do not shadow the strategy's `metadata`
            labels_mean = float(model_meta["labels_mean"][target])
            labels_std = float(model_meta["labels_std"][target])
            logger.info(f"[{pair}] 从最新子目录 {latest_sub_dir} 读取 labels_mean{labels_mean}, labels_std{labels_std}")
            return labels_mean, labels_std
        except Exception as e:  # best-effort lookup: any failure falls back to recomputation
            logger.warning(f"[{pair}] 无法从子目录读取 labels_mean 和 labels_std{e},重新计算")
            return None, None

    def _recent_label_stats(self, dataframe: DataFrame):
        """Recompute ``(mean, std)`` of the actual divergence over the last ``fit_live_predictions_candles`` rows.

        Also writes the helper column ``&-price_value_divergence_actual`` into *dataframe*.
        """
        actual = (dataframe["close"] - dataframe["ema200"]) / dataframe["ema200"]
        dataframe["&-price_value_divergence_actual"] = actual.replace([np.inf, -np.inf], 0).ffill().fillna(0)
        recent_data = dataframe["&-price_value_divergence_actual"].tail(self.fit_live_predictions_candles)
        return recent_data.mean(), recent_data.std()

    def _dynamic_thresholds(self, pair: str, labels_mean: float, labels_std: float):
        """Return ``(k_buy, k_sell, buy_threshold, sell_threshold)`` for the given label statistics.

        The multipliers widen with volatility (``labels_std``), and ``k_sell`` grows
        further when ``labels_mean`` is high. The buy threshold is clamped to
        [-0.05, -0.005] and the sell threshold to [0.005, 0.05].
        """
        if labels_std > 0.015:
            k_buy, k_sell = 1.2, 1.5
        elif labels_std < 0.010:
            k_buy, k_sell = 0.8, 1.0
        else:
            k_buy, k_sell = 1.0, 1.2
        if labels_mean > 0.015:
            k_sell += 0.5
            logger.info(f"[{pair}] labels_mean 较高({labels_mean:.4f}),增加 k_sell 到 {k_sell:.2f}")
        buy_threshold = min(max(labels_mean - k_buy * labels_std, -0.05), -0.005)
        sell_threshold = max(min(labels_mean + k_sell * labels_std, 0.05), 0.005)
        return k_buy, k_sell, buy_threshold, sell_threshold

    def _thresholds_for(self, pair: str):
        """Return ``(buy_threshold, sell_threshold)`` for *pair*, or ``None`` if none are known.

        Per-pair values written by populate_indicators take precedence. Otherwise
        the instance attributes are used (if set), preserving the original behaviour.
        """
        stats = self.pair_stats.get(pair, {})
        if "buy_threshold" in stats and "sell_threshold" in stats:
            return stats["buy_threshold"], stats["sell_threshold"]
        buy = getattr(self, "buy_threshold", None)
        sell = getattr(self, "sell_threshold", None)
        if buy is None or sell is None:
            return None
        return buy, sell

    def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Set ``enter_long = 1`` on rows meeting all four buy conditions for this pair.

        Rows that also carry ``exit_long == 1`` have their entry cleared.
        """
        pair = metadata.get('pair', 'Unknown')
        if "&-price_value_divergence" not in dataframe.columns:
            logger.warning(f"[{pair}] ⚠️ &-price_value_divergence 列缺失,跳过该条件")
            return dataframe
        thresholds = self._thresholds_for(pair)
        if thresholds is None:
            logger.warning(f"[{pair}] 未找到该币对的动态阈值,跳过")
            return dataframe
        buy_threshold, _ = thresholds

        # Evaluate each buy condition individually (for the debug log below).
        cond1 = dataframe["&-price_value_divergence"] < buy_threshold
        cond2 = dataframe["volume_z_score"] > 1.5
        cond3 = dataframe["rsi"] < 40
        cond4 = dataframe["close"] <= dataframe["bb_lowerband"]
        if not dataframe.empty:
            logger.debug(f"[{pair}] 买入条件检查 - "
                         f"&-price_value_divergence < {buy_threshold:.6f}: {cond1.iloc[-1]}, "
                         f"volume_z_score > 1.5: {cond2.iloc[-1]}, "
                         f"rsi < 40: {cond3.iloc[-1]}, "
                         f"close <= bb_lowerband: {cond4.iloc[-1]}")
        dataframe.loc[reduce(lambda x, y: x & y, [cond1, cond2, cond3, cond4]), 'enter_long'] = 1

        # Drop entries only on rows that also carry an exit signal.
        if 'exit_long' in dataframe.columns:
            conflicting = (dataframe["exit_long"] == 1) & (dataframe["enter_long"] == 1)
            if conflicting.any():
                logger.warning(f"[{pair}] 同时检测到买入和卖出信号,忽略买入信号")
                dataframe.loc[conflicting, 'enter_long'] = 0
        if not dataframe.empty and dataframe["enter_long"].iloc[-1] == 1:
            logger.debug(f"[{pair}] 入场信号触发,条件满足")
        return dataframe

    def populate_exit_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """Set ``exit_long = 1`` where the divergence exceeds this pair's sell threshold or RSI > 75."""
        pair = metadata.get('pair', 'Unknown')
        if "&-price_value_divergence" not in dataframe.columns:
            logger.warning(f"[{pair}] ⚠️ &-price_value_divergence 列缺失,跳过该条件")
            return dataframe
        thresholds = self._thresholds_for(pair)
        if thresholds is None:
            logger.warning(f"[{pair}] 未找到该币对的动态阈值,跳过")
            return dataframe
        _, sell_threshold = thresholds

        cond1 = dataframe["&-price_value_divergence"] > sell_threshold
        cond2 = dataframe["rsi"] > 75
        if not dataframe.empty:
            logger.debug(f"[{pair}] 卖出条件检查 - "
                         f"&-price_value_divergence > {sell_threshold:.6f}: {cond1.iloc[-1]}, "
                         f"rsi > 75: {cond2.iloc[-1]}")
        dataframe.loc[cond1 | cond2, 'exit_long'] = 1
        if not dataframe.empty and dataframe["exit_long"].iloc[-1] == 1:
            logger.debug(f"[{pair}] 出场信号触发,条件满足")
        return dataframe

    def adjust_trade_position(self, trade: Trade, current_time: datetime.datetime,
                              current_rate: float, current_profit: float,
                              min_roi: Optional[Dict[float, float]] = None,
                              max_profit: Optional[float] = None, **kwargs) -> Optional[float]:
        """Fully exit a trade on a per-trade trailing stop or after 30 minutes in the position.

        NOTE(review): freqtrade only invokes this callback when
        ``position_adjustment_enable`` is True, which this strategy does not set.
        It passes keyword arguments (``min_stake``, ``max_stake``, ...), which
        ``**kwargs`` absorbs. ``min_roi``/``max_profit`` are unused and kept
        optional for backward compatibility.

        Returns ``-trade.stake_amount`` to sell the whole position, or ``None`` to do nothing.
        A negative stake reduces the position by that amount, so the original ``-1``
        sold only one unit of stake currency.
        """
        hold_time = (current_time - trade.open_date_utc).total_seconds() / 60
        if hold_time < 15:
            logger.info(f"[{trade.pair}] 持仓时间 {hold_time:.1f} 分钟,未达到最小持仓时间 15 分钟,暂不退出")
            return None

        # Trailing-stop state is derived from this trade's own peak price.
        # A strategy-wide flag would leak between concurrently open trades.
        peak_rate = max(trade.max_rate or trade.open_rate, current_rate)
        if peak_rate >= trade.open_rate * (1 + self.trailing_stop_start):
            logger.debug(f"[{trade.pair}] 价格上涨超过 {self.trailing_stop_start*100:.1f}%,启动 Trailing Stop")
            trailing_stop_price = peak_rate * (1 - self.trailing_stop_distance)
            if current_rate < trailing_stop_price:
                logger.info(f"[{trade.pair}] 价格回落至 Trailing Stop 点 {trailing_stop_price:.6f},触发卖出")
                return -trade.stake_amount

        if hold_time > 30:
            logger.info(f"[{trade.pair}] 持仓时间超过30分钟强制平仓")
            return -trade.stake_amount
        return None

    def confirm_trade_entry(self, pair: str, order_type: str, amount: float, rate: float,
                            time_in_force: str, current_time: datetime.datetime, **kwargs) -> bool:
        """Reject an entry when a trade on *pair* was closed within the last 5 minutes.

        Uses ``Trade.get_trades_proxy``, which works in both live/dry-run and
        backtesting. It replaces a ``Trade.query`` filter built with an unimported
        ``and_``, which always raised ``NameError``.
        """
        recent_trades = Trade.get_trades_proxy(
            pair=pair,
            is_open=False,
            close_date=current_time - datetime.timedelta(minutes=5),
        )
        if recent_trades:
            logger.info(f"[{pair}] 5分钟内有近期交易{len(recent_trades)} 笔),跳过本次入场")
            return False
        logger.debug(f"[{pair}] 允许买入 - 订单类型:{order_type}, 数量:{amount:.2f}, 价格:{rate:.6f}, 时间:{current_time}")
        return True

    def confirm_trade_exit(self, pair: str, trade: Trade, order_type: str, amount: float,
                           rate: float, time_in_force: str, exit_reason: str,
                           current_time: datetime.datetime, **kwargs) -> bool:
        """Log the exit reason and realised profit ratio, then always allow the exit."""
        logger.info(f"[{pair}] 退出交易,原因:{exit_reason}, 利润:{trade.calc_profit_ratio(rate):.2%}")
        return True