sr-download/sr_download/sql/xml_parse.py

119 lines
3.5 KiB
Python

from __future__ import annotations
import time
import xml.etree.ElementTree as ET
import psycopg2
import tomli
from lib_not_dr import loggers
with open("../../config.toml", "rb") as f:
CONFIG = tomli.load(f)
logger = loggers.config.get_logger("xml_parse")
# logger.global_level = 10
result_logger = loggers.config.get_logger("result")
file = loggers.outstream.FileCacheOutputStream(file_name="result.log")
result_logger.add_output(file)
def get_db():
connect = psycopg2.connect(
CONFIG["db"]["url"]
)
return connect
def parse_part_types(xml_file):
tree = ET.parse(xml_file)
root = tree.getroot()
part_dict = {}
# 遍历所有 PartType 元素
for part in root.findall('PartType'):
part_id = part.get('id')
part_mass = float(part.get('mass')) # type: ignore
part_dict[part_id] = part_mass * 500 # 缩放因子
return part_dict
def fetch_data(db_cur, offset, limit):
xml_fetch = f"""
WITH data AS (
SELECT save_id as id, data
FROM public.full_data
WHERE "save_type" = 'ship'
AND full_data.xml_tested
AND save_id >= 1200000
ORDER BY save_id ASC
LIMIT {limit} OFFSET {offset}
),
parts_data AS (
SELECT data.id, parts.part_type
FROM data,
XMLTABLE (
'//Ship/Parts/Part'
PASSING BY VALUE xmlparse(document data."data")
COLUMNS part_type text PATH '@partType',
part_id text PATH '@id'
) AS parts
)
SELECT id,
array_agg(part_type) AS part_types,
array_agg(part_count) AS part_counts
FROM (
SELECT id, part_type, COUNT(part_type) AS part_count
FROM parts_data
GROUP BY id, part_type
) AS counted_parts
GROUP BY id;
"""
db_cur.execute(xml_fetch)
return db_cur.fetchall()
def main():
part_mass = parse_part_types("PartList.xml")
logger.debug(part_mass)
db = get_db()
db_cur = db.cursor()
offset = 0
limit = 100
target_mass = 775150
delta = 10000
start_time = time.time()
while True:
datas = fetch_data(db_cur, offset, limit)
if not datas:
break
for data in datas:
logger.debug(data)
# (id, ['pod-1', 'parachute-1', 'lander-1', 'engine-1', 'fuselage-1', 'fueltank-0'], [1, 1, 2, 1, 1, 1])
# 解析这玩意
# ship_mass = sum(part_mass[part] * data[2][idx] for part, idx in enumerate(data[1]))
ship_mass = sum(part_mass[part] * count if part in part_mass else 0 for part, count in zip(data[1], data[2]))
logger.debug(f"{data[0]}: {ship_mass}")
# 如果在误差范围内,则输出
if target_mass - delta <= ship_mass <= target_mass + delta:
result_logger.info(f"Ship ID: {data[0]}")
result_logger.info(f"Mass: {ship_mass}")
result_logger.info(f"Part Types: {data[1]}")
result_logger.info(f"Part Counts: {data[2]}")
result_logger.info("=" * 20)
logger.warn(f"Ship ID: {data[0]}")
logger.warn(f"Mass: {ship_mass}")
logger.warn(f"Part Types: {data[1]}")
logger.warn(f"Part Counts: {data[2]}")
logger.warn("=" * 20)
offset += limit
# 效率
delta_t = time.time() - start_time
speed = limit / delta_t
logger.info(f"Offset: {offset} Speed: {speed:.2f} ships/s({delta_t:.2f}s) - {data[0]}")
start_time = time.time()
main()