# myTestFreqAI/tools/evaluation.py
# Repository viewer snapshot: 2026-02-02 07:57:46 +08:00 (327 lines, 14 KiB, Python).
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters that might be confused with other characters. If that is
# intentional, the warning can be safely ignored.

import argparse
import decimal
import os
from collections import defaultdict
from datetime import datetime

import pymysql
# -------------------------- Configuration --------------------------
# Connection settings passed to pymysql.connect() by get_db_connection().
# Every value can be overridden through an environment variable so that real
# credentials do not have to be edited into this file; the literals are the
# historical defaults and keep existing deployments working unchanged.
# NOTE(review): the default password is still committed here for backward
# compatibility -- rotate it and supply FREQAI_DB_PASSWORD instead.
DB_CONFIG = {
    "host": os.environ.get("FREQAI_DB_HOST", "192.168.1.215"),  # database host
    "port": int(os.environ.get("FREQAI_DB_PORT", "3306")),  # TCP port
    "user": os.environ.get("FREQAI_DB_USER", "freqai"),  # user name
    "password": os.environ.get("FREQAI_DB_PASSWORD", "Abcd@1234"),  # password
    "database": os.environ.get("FREQAI_DB_NAME", "freqai"),  # database name
    "charset": os.environ.get("FREQAI_DB_CHARSET", "utf8mb4"),  # connection charset
}

# Not referenced by any code visible in this file.
MIN_SPAN_DAYS = 5
# is_span_valid() rejects spans that start before CUTOFF_DATE and last fewer
# than SHORT_SPAN_THRESHOLD days.
CUTOFF_DATE = datetime(2025, 9, 10)
SHORT_SPAN_THRESHOLD = 50
# -------------------------- Utility functions --------------------------
def truncate_commitid(commitid):
    """Return the first eight characters of ``commitid``.

    A falsy value (``None`` or an empty string) yields an empty string.
    """
    if not commitid:
        return ""
    return commitid[:8]
def format_timespan(from_str, to_str):
    """Join the date parts of two timestamp strings into one compact label.

    For each argument only the text before the first space is kept and every
    ``-`` in it is removed; the two results are joined with a single ``-``
    (e.g. ``"2025-01-01 00:00:00"`` and ``"2025-02-01 00:00:00"`` give
    ``"20250101-20250201"``).
    """
    compact_parts = []
    for stamp in (from_str, to_str):
        date_part = stamp.split(" ")[0]
        compact_parts.append(date_part.replace("-", ""))
    return "-".join(compact_parts)
def parse_formatted_timespan(formatted_span):
    """Parse a ``YYYYMMDD-YYYYMMDD`` span into two ``datetime`` objects.

    Returns:
        ``(from_date, to_date)`` on success. ``(None, None)`` when the input
        is not a string, does not split into exactly two ``-``-separated
        parts, a part is not exactly eight digits, or a part is not a real
        calendar date.
    """
    if not isinstance(formatted_span, str):
        return None, None
    parts = formatted_span.split("-")
    if len(parts) != 2:
        return None, None
    # strptime("%Y%m%d") also accepts shortened input such as "202511"
    # (parsed as 2025-01-01), so the field width is enforced explicitly.
    if not all(len(part) == 8 and part.isdigit() for part in parts):
        return None, None
    try:
        from_date = datetime.strptime(parts[0], "%Y%m%d")
        to_date = datetime.strptime(parts[1], "%Y%m%d")
    except ValueError:
        # Eight digits that do not form a valid date, e.g. 20250230.
        return None, None
    return from_date, to_date
def get_span_days(formatted_span):
    """Return the inclusive number of days covered by ``formatted_span``.

    The span is parsed with ``parse_formatted_timespan``; when that fails the
    result is ``0``. Otherwise it is the day difference between the two dates
    plus one.
    """
    start, end = parse_formatted_timespan(formatted_span)
    if start and end:
        return (end - start).days + 1
    return 0
def is_span_valid(formatted_span):
    """Tell whether ``formatted_span`` passes the cutoff/length filter.

    Returns ``False`` when the span cannot be parsed, or when it starts before
    ``CUTOFF_DATE`` and covers fewer than ``SHORT_SPAN_THRESHOLD`` days;
    ``True`` otherwise.
    """
    start, _ = parse_formatted_timespan(formatted_span)
    if not start:
        return False
    too_short = get_span_days(formatted_span) < SHORT_SPAN_THRESHOLD
    return not (start < CUTOFF_DATE and too_short)
# Added: compute the missing spans (target spans minus covered spans)
def get_missing_spans(target_spans, covered_spans):
    """Return the entries of ``target_spans`` that are not in ``covered_spans``.

    The order (and any repetition) of ``target_spans`` is preserved.
    """
    missing = []
    for candidate in target_spans:
        if candidate in covered_spans:
            continue
        missing.append(candidate)
    return missing
def validate_timespan_format(formatted_span):
    """Check that a span parses and that its start is not after its end.

    Returns ``True`` only when ``parse_formatted_timespan`` yields two dates
    and the first one is less than or equal to the second.
    """
    start, end = parse_formatted_timespan(formatted_span)
    if start and end:
        return start <= end
    return False
# -------------------------- Command-line argument parsing --------------------------
def parse_command_line_args(argv=None):
    """Parse ``--timespans`` and return the valid, de-duplicated spans.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what existing callers get.

    Returns:
        The spans that pass ``validate_timespan_format``, in first-seen order
        with duplicates removed.

    Raises:
        ValueError: if no valid span remains.
        SystemExit: raised by argparse when ``--timespans`` is missing.
    """
    parser = argparse.ArgumentParser(description="查询指定时段组合下表现最优的CommitID")
    parser.add_argument(
        "--timespans",
        type=str,
        required=True,
        help="指定目标时段组合格式YYYYMMDD-YYYYMMDD,YYYYMMDD-YYYYMMDD多个时段用逗号分隔",
    )
    args = parser.parse_args(argv)

    # Split on commas, trim whitespace and drop empty pieces. dict.fromkeys
    # removes duplicates while keeping first-seen order; without it a repeated
    # span would inflate the "covered/total" counts reported later.
    stripped = (span.strip() for span in args.timespans.split(","))
    target_timespans = list(dict.fromkeys(span for span in stripped if span))

    # Separate well-formed spans from malformed ones.
    valid_timespans = []
    invalid_timespans = []
    for span in target_timespans:
        if validate_timespan_format(span):
            valid_timespans.append(span)
        else:
            invalid_timespans.append(span)

    # Report what was ignored so the user can correct the command line.
    if invalid_timespans:
        print("⚠️ 以下时段格式不合法需为YYYYMMDD-YYYYMMDD且开始日期≤结束日期已忽略")
        for span in invalid_timespans:
            print(f" - {span}")
    if not valid_timespans:
        raise ValueError("❌ 无有效时段传入,程序终止")
    return valid_timespans
# -------------------------- Database access --------------------------
def get_db_connection():
    """Open a new pymysql connection using the settings in ``DB_CONFIG``.

    A ``pymysql.MySQLError`` raised while connecting is reported on stdout and
    then re-raised to the caller.
    """
    try:
        connection = pymysql.connect(**DB_CONFIG)
    except pymysql.MySQLError as exc:
        print(f"数据库连接失败: {exc}")
        raise
    return connection
def _rank_span_records(records):
    """Turn the rows of one span into a ``{commitid: rank_score}`` mapping.

    Rows without a commit id are ignored. When a commit id occurs in several
    rows its highest ``corrected_product`` is kept, so the result does not
    depend on the (unspecified) row order returned by the database. Commits
    are then sorted by that value in descending order; the first gets 100
    points and each following position 10 points fewer, floored at 0.
    """
    best_by_cid = {}
    for record in records:
        cid = record["commitid"]
        if not cid:
            continue
        score = decimal.Decimal(str(record["corrected_product"]))
        if cid not in best_by_cid or score > best_by_cid[cid]:
            best_by_cid[cid] = score
    ranked = sorted(best_by_cid.items(), key=lambda item: item[1], reverse=True)
    return {cid: max(0, 100 - idx * 10) for idx, (cid, _) in enumerate(ranked)}


def get_backtest_data_in_target_spans(target_timespans):
    """Load and rank the backtest results stored for each target span.

    Args:
        target_timespans: ``YYYYMMDD-YYYYMMDD`` spans; each must be parseable
            by ``parse_formatted_timespan``.

    Returns:
        A tuple ``(target_backtest_data, cid_covered_spans)``:

        * ``target_backtest_data`` -- ``defaultdict(dict)`` mapping a span to
          ``{commitid: rank_score}``. A span only gets an entry when at least
          one row with a commit id was found for it.
        * ``cid_covered_spans`` -- ``defaultdict(list)`` mapping a commit id to
          the spans it has a result for.

    A span whose query raises ``pymysql.MySQLError`` is reported on stdout and
    skipped. The cursor and connection are closed even when another exception
    propagates.
    """
    # Convert every span to the "YYYY-MM-DD 00:00:00" strings that the query
    # compares against the stored report fields. Done before connecting so a
    # bad span cannot leave a connection open.
    target_span_original = {}
    for formatted_span in target_timespans:
        from_date, to_date = parse_formatted_timespan(formatted_span)
        target_span_original[formatted_span] = (
            from_date.strftime("%Y-%m-%d 00:00:00"),
            to_date.strftime("%Y-%m-%d 00:00:00"),
        )

    # corrected_product is SIGN(total_profit) * total_profit * sqn, or 0 when
    # either field is missing. The statement is the same for every span, so it
    # is built once.
    sql = """
        SELECT
            JSON_UNQUOTE(JSON_EXTRACT(result, '$.commitid')) AS commitid,
            IF(
                JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) IS NULL
                OR JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.sqn')) IS NULL,
                0,
                SIGN(CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) AS DECIMAL(10,4)))
                * CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) AS DECIMAL(10,4))
                * CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.sqn')) AS DECIMAL(10,4))
            ) AS corrected_product
        FROM backtestresult
        WHERE
            JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.backtesting_from_str')) = %s
            AND JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.backtesting_to_str')) = %s;
    """

    target_backtest_data = defaultdict(dict)
    cid_covered_spans = defaultdict(list)

    conn = get_db_connection()
    try:
        cursor = conn.cursor(pymysql.cursors.DictCursor)
        try:
            for formatted_span, (from_str, to_str) in target_span_original.items():
                try:
                    cursor.execute(sql, (from_str, to_str))
                    records = cursor.fetchall()
                except pymysql.MySQLError as e:
                    print(f"查询目标时段 {formatted_span} 失败: {e}")
                    continue
                # Spans are unique dict keys and the helper returns each commit
                # once, so no further de-duplication is needed here.
                for cid, rank_score in _rank_span_records(records).items():
                    target_backtest_data[formatted_span][cid] = rank_score
                    cid_covered_spans[cid].append(formatted_span)
        finally:
            cursor.close()
    finally:
        conn.close()
    return target_backtest_data, cid_covered_spans
# -------------------------- Core: find the best CommitID for the target span combination --------------------------
def find_best_commitid_in_target_spans(target_timespans):
    """Rank every CommitID by its summed rank score over the target spans.

    Returns:
        ``(best, ranking)``. ``ranking`` is a list of dicts with the keys
        ``commitid``, ``short_id``, ``total_score``, ``avg_score``,
        ``valid_span_count``, ``covered_spans`` and ``missing_spans``, sorted
        by total score in descending order; ``best`` is its first element.
        When no backtest data is found a message is printed and
        ``(None, [])`` is returned.
    """
    # Step 1: per-span rank scores plus the spans each CommitID has data for.
    span_scores, covered_by_cid = get_backtest_data_in_target_spans(target_timespans)
    if not span_scores:
        print("❌ 目标时段组合内无回测数据")
        return None, []

    # Step 2: accumulate the total score and the span count per CommitID,
    # keeping the order in which the CommitIDs are first seen.
    totals = {}
    span_counts = {}
    for per_commit in span_scores.values():
        for cid, score in per_commit.items():
            totals[cid] = totals.get(cid, decimal.Decimal("0")) + score
            span_counts[cid] = span_counts.get(cid, 0) + 1

    # Step 3: order by the exact Decimal total, highest first (stable sort).
    ordered = sorted(totals.items(), key=lambda pair: pair[1], reverse=True)

    # Step 4: build the plain-dict result rows.
    ranking = []
    for cid, total in ordered:
        count = span_counts[cid]
        covered = covered_by_cid.get(cid, [])
        if count > 0:
            average = total / decimal.Decimal(str(count))
        else:
            average = decimal.Decimal("0")
        ranking.append({
            "commitid": cid,
            "short_id": truncate_commitid(cid),
            "total_score": float(total),
            "avg_score": float(average),
            "valid_span_count": count,
            "covered_spans": covered,
            "missing_spans": get_missing_spans(target_timespans, covered),
        })

    best = ranking[0] if ranking else None
    return best, ranking
# -------------------------- Main entry point --------------------------
def main():
    """Command-line entry point: print the ranking report and write the JSON file.

    Parses ``--timespans``, ranks the CommitIDs over those spans, prints a
    Markdown report to stdout and writes the missing-span details of the top
    entries to ``commitid_missing_spans.json`` in the current directory.
    Returns early (after printing a message) when no valid span was given or
    no backtest data exists.
    """
    import json

    top_n = 50  # number of ranked CommitIDs shown in the table, list and JSON
    json_path = "commitid_missing_spans.json"

    # Step 1: parse the command line; without a valid span there is nothing to do.
    try:
        target_timespans = parse_command_line_args()
    except ValueError as e:
        print(e)
        return
    span_total = len(target_timespans)

    # Step 2: echo the target spans.
    print("# 特定时段组合下最优CommitID查询报告含缺失时段")
    print(f"\n## 目标时段组合(共 {span_total} 个有效时段)")
    for idx, span in enumerate(target_timespans, 1):
        days = get_span_days(span)
        print(f"{idx}. {span}(跨度:{days}天)")

    # Step 3: rank the CommitIDs.
    print("\n## 查询结果")
    best_commitid, result_list = find_best_commitid_in_target_spans(target_timespans)
    if not best_commitid:
        return
    top_items = result_list[:top_n]

    # Best CommitID, including its covered and missing spans.
    print("### 表现最出色的CommitID")
    print(f"- CommitID完整{best_commitid['commitid']}")
    print(f"- CommitID前8位{best_commitid['short_id']}")
    print(f"- 目标时段总得分:{round(best_commitid['total_score'], 4)}")
    print(f"- 目标时段平均得分:{round(best_commitid['avg_score'], 4)}")
    print(f"- 有效覆盖时段数:{best_commitid['valid_span_count']}/{span_total}")
    print(f"- 已覆盖时段:{', '.join(best_commitid['covered_spans']) if best_commitid['covered_spans'] else ''}")
    print(f"- 缺失时段:{', '.join(best_commitid['missing_spans']) if best_commitid['missing_spans'] else ''}")

    # Ranking table for the top entries.
    print(f"\n### 目标时段组合下CommitID排名前{top_n}含缺失时段")
    print("| 排名 | CommitID前8位 | 总得分 | 平均得分 | 有效覆盖时段数 | 缺失时段数 | 缺失时段详情 |")
    print("|------|------------------|--------|----------|----------------|------------|--------------|")
    for idx, item in enumerate(top_items, 1):
        missing_count = len(item["missing_spans"])
        missing_detail = ", ".join(item["missing_spans"]) if missing_count > 0 else ""
        # Keep the table cell short: show the first three spans followed by
        # the total count instead of gluing the count onto the last span.
        if len(missing_detail) > 100:
            missing_detail = f"{', '.join(item['missing_spans'][:3])} 等{missing_count}个"
        print(f"| {idx} | {item['short_id']} | {round(item['total_score'],4)} | {round(item['avg_score'],4)} | {item['valid_span_count']}/{span_total} | {missing_count} | {missing_detail} |")

    # Full missing-span list for the top entries (Markdown).
    print(f"\n### 前{top_n}名CommitID缺失时段详细清单Markdown版本")
    for idx, item in enumerate(top_items, 1):
        print(f"\n{idx}. {item['short_id']}(缺失 {len(item['missing_spans'])} 个时段):")
        if item["missing_spans"]:
            for span in item["missing_spans"]:
                print(f" - {span}(跨度:{get_span_days(span)}天)")
        else:
            print(" - 无缺失时段(完整覆盖所有目标时段)")

    # Same information as JSON, written to a file.
    print(f"\n### 前{top_n}名CommitID缺失时段详细清单JSON版本")
    json_output = [
        {
            "rank": idx,
            "commitid_short": item["short_id"],
            "missing_count": len(item["missing_spans"]),
            "missing_spans": [
                {"timespan": span, "duration_days": get_span_days(span)}
                for span in item["missing_spans"]
            ],
        }
        for idx, item in enumerate(top_items, 1)
    ]
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_output, f, indent=4, ensure_ascii=False)
    # Announce the file only after it has actually been written.
    print(f"JSON数据已保存到{json_path}")
# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()