import argparse
import decimal
import json
from collections import defaultdict
from datetime import datetime

import pymysql

# -------------------------- Configuration --------------------------
# NOTE(review): credentials are hard-coded in source; consider loading them
# from the environment or an untracked config file.
DB_CONFIG = {
    "host": "192.168.1.215",  # database host
    "port": 3306,  # port
    "user": "freqai",  # user name
    "password": "Abcd@1234",  # password
    "database": "freqai",  # database name
    "charset": "utf8mb4",  # character set
}
MIN_SPAN_DAYS = 5
CUTOFF_DATE = datetime(2025, 9, 10)
SHORT_SPAN_THRESHOLD = 50


# -------------------------- Helper functions --------------------------
def truncate_commitid(commitid):
    """Return the first 8 characters of *commitid*, or "" when it is falsy."""
    return commitid[:8] if commitid else ""


def format_timespan(from_str, to_str):
    """Build a ``YYYYMMDD-YYYYMMDD`` span from two ``YYYY-MM-DD HH:MM:SS`` strings."""
    from_date = from_str.split(" ")[0].replace("-", "")
    to_date = to_str.split(" ")[0].replace("-", "")
    return f"{from_date}-{to_date}"


def parse_formatted_timespan(formatted_span):
    """Parse a ``YYYYMMDD-YYYYMMDD`` span into two ``datetime`` objects.

    Returns ``(from_date, to_date)`` on success and ``(None, None)`` when the
    span cannot be parsed (wrong number of parts, bad date, non-string input).
    """
    try:
        from_date_str, to_date_str = formatted_span.split("-")
        from_date = datetime.strptime(from_date_str, "%Y%m%d")
        to_date = datetime.strptime(to_date_str, "%Y%m%d")
    except (ValueError, IndexError, AttributeError):
        return None, None
    return from_date, to_date


def get_span_days(formatted_span):
    """Return the inclusive number of days in the span, or 0 if unparseable."""
    from_date, to_date = parse_formatted_timespan(formatted_span)
    if not from_date or not to_date:
        return 0
    return (to_date - from_date).days + 1


def is_span_valid(formatted_span):
    """Return False for unparseable spans and for short spans before the cutoff.

    A span is rejected when it starts before ``CUTOFF_DATE`` and covers fewer
    than ``SHORT_SPAN_THRESHOLD`` days.
    """
    from_date, _ = parse_formatted_timespan(formatted_span)
    if not from_date:
        return False
    span_days = get_span_days(formatted_span)
    if from_date < CUTOFF_DATE and span_days < SHORT_SPAN_THRESHOLD:
        return False
    return True


def get_missing_spans(target_spans, covered_spans):
    """Return the target spans (in target order) absent from *covered_spans*."""
    covered = set(covered_spans)
    return [span for span in target_spans if span not in covered]


def validate_timespan_format(formatted_span):
    """Return True when the span parses as ``YYYYMMDD-YYYYMMDD`` with start <= end."""
    from_date, to_date = parse_formatted_timespan(formatted_span)
    if not from_date or not to_date:
        return False
    return from_date <= to_date


# -------------------------- Command-line parsing --------------------------
def parse_command_line_args(argv=None):
    """Parse ``--timespans`` and return the list of valid target spans.

    Args:
        argv: Optional argument list; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The valid spans, de-duplicated, in the order first given.

    Raises:
        ValueError: If no valid span remains after validation.
    """
    parser = argparse.ArgumentParser(description="查询指定时段组合下表现最优的CommitID")
    parser.add_argument(
        "--timespans",
        type=str,
        required=True,
        help="指定目标时段组合,格式:YYYYMMDD-YYYYMMDD,YYYYMMDD-YYYYMMDD(多个时段用逗号分隔)",
    )
    args = parser.parse_args(argv)

    # Split on commas, drop empty entries, and de-duplicate while preserving
    # order; duplicates would otherwise inflate the "x/total" span counts.
    stripped = (span.strip() for span in args.timespans.split(","))
    target_timespans = list(dict.fromkeys(span for span in stripped if span))

    # Validate the format of every span.
    valid_timespans = []
    invalid_timespans = []
    for span in target_timespans:
        if validate_timespan_format(span):
            valid_timespans.append(span)
        else:
            invalid_timespans.append(span)

    # Report the spans that were rejected.
    if invalid_timespans:
        print("⚠️ 以下时段格式不合法(需为YYYYMMDD-YYYYMMDD,且开始日期≤结束日期),已忽略:")
        for span in invalid_timespans:
            print(f" - {span}")

    if not valid_timespans:
        raise ValueError("❌ 无有效时段传入,程序终止")
    return valid_timespans


# -------------------------- Database access --------------------------
def get_db_connection():
    """Open a connection using ``DB_CONFIG``; log and re-raise on failure."""
    try:
        return pymysql.connect(**DB_CONFIG)
    except pymysql.MySQLError as e:
        print(f"数据库连接失败: {e}")
        raise


def get_backtest_data_in_target_spans(target_timespans):
    """Fetch rank scores for every commit id in each target span.

    Args:
        target_timespans: Iterable of ``YYYYMMDD-YYYYMMDD`` spans.

    Returns:
        A tuple ``(target_backtest_data, cid_covered_spans)`` where
        ``target_backtest_data[span][commitid]`` is the rank score of that
        commit in that span and ``cid_covered_spans[commitid]`` lists the
        spans in which the commit has data. Spans without rows, or whose
        query fails, do not appear in ``target_backtest_data``.

    Raises:
        pymysql.MySQLError: If the database connection cannot be opened.
    """
    # Convert each span to the stored "YYYY-MM-DD HH:MM:SS" form used by the query.
    target_span_original = {}
    for formatted_span in target_timespans:
        from_date, to_date = parse_formatted_timespan(formatted_span)
        if not from_date or not to_date:
            # Unparseable spans cannot be queried; skip rather than crash.
            print(f"⚠️ 目标时段 {formatted_span} 格式不合法,已跳过")
            continue
        from_str = from_date.strftime("%Y-%m-%d 00:00:00")
        to_str = to_date.strftime("%Y-%m-%d 00:00:00")
        target_span_original[formatted_span] = (from_str, to_str)

    # Score per row = SIGN(profit) * profit * sqn, or 0 when either is missing.
    sql = """
        SELECT
            JSON_UNQUOTE(JSON_EXTRACT(result, '$.commitid')) AS commitid,
            IF(
                JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) IS NULL
                OR JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.sqn')) IS NULL,
                0,
                SIGN(CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) AS DECIMAL(10,4)))
                * CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) AS DECIMAL(10,4))
                * CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.sqn')) AS DECIMAL(10,4))
            ) AS corrected_product
        FROM backtestresult
        WHERE JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.backtesting_from_str')) = %s
          AND JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.backtesting_to_str')) = %s;
    """

    target_backtest_data = defaultdict(dict)
    cid_covered_spans = defaultdict(list)

    conn = get_db_connection()
    try:
        cursor = conn.cursor(pymysql.cursors.DictCursor)
        try:
            for formatted_span, (from_str, to_str) in target_span_original.items():
                try:
                    cursor.execute(sql, (from_str, to_str))
                    records = cursor.fetchall()
                except pymysql.MySQLError as e:
                    print(f"查询目标时段 {formatted_span} 失败: {e}")
                    continue

                # One score per commit id within this span.
                # NOTE(review): when a commit has several rows for the same
                # span, the last row returned wins; the query has no ORDER BY.
                span_scores = {}
                for record in records:
                    cid = record["commitid"]
                    if not cid:
                        continue
                    raw_score = record["corrected_product"]
                    if raw_score is None:
                        # Mirror the SQL, which maps missing values to 0.
                        raw_score = 0
                    span_scores[cid] = decimal.Decimal(str(raw_score))

                if not span_scores:
                    continue

                # Rank by score, best first. Scoring is a fixed step, not a
                # percentile: 1st = 100, 2nd = 90, ..., 10th = 10, 11th+ = 0.
                sorted_cids = sorted(
                    span_scores.items(), key=lambda x: x[1], reverse=True
                )
                for idx, (cid, _score) in enumerate(sorted_cids):
                    rank_score = max(0, 100 - idx * 10)
                    target_backtest_data[formatted_span][cid] = rank_score
                    # Record that this commit covers the current span (no duplicates).
                    if formatted_span not in cid_covered_spans[cid]:
                        cid_covered_spans[cid].append(formatted_span)
        finally:
            cursor.close()
    finally:
        # Release the connection even if an unexpected error occurs.
        conn.close()

    return target_backtest_data, cid_covered_spans


# ---------------- Core: best commit id across the target spans ----------------
def find_best_commitid_in_target_spans(target_timespans):
    """Rank commit ids by total rank score across the target spans.

    Args:
        target_timespans: List of ``YYYYMMDD-YYYYMMDD`` spans.

    Returns:
        ``(best, results)`` where ``results`` is a list of dicts sorted by
        ``total_score`` descending (keys: ``commitid``, ``short_id``,
        ``total_score``, ``avg_score``, ``valid_span_count``,
        ``covered_spans``, ``missing_spans``) and ``best`` is its first
        element. Returns ``(None, [])`` when there is no data at all.
    """
    # Step 1: per-span scores plus the spans covered by each commit id.
    target_backtest_data, cid_covered_spans = get_backtest_data_in_target_spans(
        target_timespans
    )
    if not target_backtest_data:
        print("❌ 目标时段组合内无回测数据")
        return None, []

    # Step 2: aggregate per commit id.
    commitid_agg_score = defaultdict(lambda: {
        "total_score": decimal.Decimal("0"),
        "avg_score": decimal.Decimal("0"),
        "valid_span_count": 0,
        "short_id": "",
        "covered_spans": [],
        "missing_spans": [],
    })
    for cid_scores in target_backtest_data.values():
        for cid, score in cid_scores.items():
            agg = commitid_agg_score[cid]
            agg["total_score"] += score
            agg["valid_span_count"] += 1
            agg["short_id"] = truncate_commitid(cid)

    # Step 3: fill in covered spans, missing spans, and the average score.
    for cid, agg_data in commitid_agg_score.items():
        agg_data["covered_spans"] = cid_covered_spans.get(cid, [])
        agg_data["missing_spans"] = get_missing_spans(
            target_timespans, agg_data["covered_spans"]
        )
        if agg_data["valid_span_count"] > 0:
            agg_data["avg_score"] = agg_data["total_score"] / decimal.Decimal(
                str(agg_data["valid_span_count"])
            )

    # Step 4: sort by total score, best first.
    sorted_commitids = sorted(
        commitid_agg_score.items(),
        key=lambda x: x[1]["total_score"],
        reverse=True,
    )

    # Convert to plain, JSON-friendly dicts.
    result_list = [
        {
            "commitid": cid,
            "short_id": agg_data["short_id"],
            "total_score": float(agg_data["total_score"]),
            "avg_score": float(agg_data["avg_score"]),
            "valid_span_count": agg_data["valid_span_count"],
            "covered_spans": agg_data["covered_spans"],
            "missing_spans": agg_data["missing_spans"],
        }
        for cid, agg_data in sorted_commitids
    ]

    best_commitid = result_list[0] if result_list else None
    return best_commitid, result_list


# -------------------------- Entry point --------------------------
def main():
    """Print a Markdown report of the best commit ids and write a JSON summary.

    Reads the target spans from the command line, prints the best commit id
    and the top-50 ranking, then writes the top-50 missing-span details to
    ``commitid_missing_spans.json`` in the current working directory.
    """
    # Step 1: parse the command line into the valid target spans.
    try:
        target_timespans = parse_command_line_args()
    except ValueError as e:
        print(e)
        return
    span_total = len(target_timespans)

    # Step 2: print the target spans.
    print("# 特定时段组合下最优CommitID查询报告(含缺失时段)")
    print(f"\n## 目标时段组合(共 {span_total} 个有效时段)")
    for idx, span in enumerate(target_timespans, 1):
        days = get_span_days(span)
        print(f"{idx}. {span}(跨度:{days}天)")

    # Step 3: find the best commit id.
    print("\n## 查询结果")
    best_commitid, result_list = find_best_commitid_in_target_spans(target_timespans)
    if not best_commitid:
        return

    # Best commit id, including its missing spans.
    print("### 表现最出色的CommitID")
    print(f"- CommitID(完整):{best_commitid['commitid']}")
    print(f"- CommitID(前8位):{best_commitid['short_id']}")
    print(f"- 目标时段总得分:{round(best_commitid['total_score'], 4)}")
    print(f"- 目标时段平均得分:{round(best_commitid['avg_score'], 4)}")
    print(f"- 有效覆盖时段数:{best_commitid['valid_span_count']}/{span_total}")
    print(f"- 已覆盖时段:{', '.join(best_commitid['covered_spans']) if best_commitid['covered_spans'] else '无'}")
    print(f"- 缺失时段:{', '.join(best_commitid['missing_spans']) if best_commitid['missing_spans'] else '无'}")

    top_items = result_list[:50]

    # Ranking table for the top 50 commit ids.
    print("\n### 目标时段组合下CommitID排名(前50,含缺失时段)")
    print("| 排名 | CommitID(前8位) | 总得分 | 平均得分 | 有效覆盖时段数 | 缺失时段数 | 缺失时段详情 |")
    print("|------|------------------|--------|----------|----------------|------------|--------------|")
    for idx, item in enumerate(top_items, 1):
        missing_count = len(item["missing_spans"])
        missing_detail = ", ".join(item["missing_spans"]) if missing_count > 0 else "无"
        # Abbreviate very long lists so the table row stays readable.
        if len(missing_detail) > 100:
            missing_detail = f"{', '.join(item['missing_spans'][:3])} 等 {missing_count} 个"
        print(f"| {idx} | {item['short_id']} | {round(item['total_score'],4)} | {round(item['avg_score'],4)} | {item['valid_span_count']}/{span_total} | {missing_count} | {missing_detail} |")

    # Full missing-span list for the top 50 commit ids (Markdown).
    print("\n### 前50名CommitID缺失时段详细清单(Markdown版本)")
    for idx, item in enumerate(top_items, 1):
        print(f"\n{idx}. {item['short_id']}(缺失 {len(item['missing_spans'])} 个时段):")
        if item["missing_spans"]:
            for span in item["missing_spans"]:
                print(f" - {span}(跨度:{get_span_days(span)}天)")
        else:
            print(" - 无缺失时段(完整覆盖所有目标时段)")

    # Same information as JSON, written to a file.
    print("\n### 前50名CommitID缺失时段详细清单(JSON版本)")
    json_output = [
        {
            "rank": idx,
            "commitid_short": item["short_id"],
            "missing_count": len(item["missing_spans"]),
            "missing_spans": [
                {"timespan": span, "duration_days": get_span_days(span)}
                for span in item["missing_spans"]
            ],
        }
        for idx, item in enumerate(top_items, 1)
    ]
    with open("commitid_missing_spans.json", "w", encoding="utf-8") as f:
        json.dump(json_output, f, indent=4, ensure_ascii=False)
    # Announce the file only after it has actually been written.
    print("JSON数据已保存到:commitid_missing_spans.json")


if __name__ == "__main__":
    main()