Difficult-Rocket/libs/MCDR/serializer.py

180 lines
6.8 KiB
Python
Raw Normal View History

2022-02-08 17:15:09 +08:00
# 本文件以 GNU Lesser General Public License v3.0 (GNU LGPL v3) 开源协议进行授权 (谢谢狐狸写出这么好的MCDR)
# 顺便说一句,我把所有的tab都改成了空格,因为我觉得空格比tab更好看(草,后半句是github copilot自动填充的)
import copy
from abc import ABC
from enum import EnumMeta
from threading import Lock
from typing import Union, TypeVar, List, Dict, Type, get_type_hints, Any
2022-03-05 23:01:20 +08:00
from semver import VersionInfo as semver_VersionInfo
from libs.semver import VersionInfo as lib_semver_VersionInfo
2022-02-16 13:36:26 +08:00
2022-02-08 17:15:09 +08:00
"""
This code comes from MCDReforged (https://github.com/Fallen-Breath/MCDReforged)
Many thanks to Fallen_Breath and the other contributors who helped make MCDR better
GNU Lesser General Public License v3.0 (GNU LGPL v3)
(with some modifications)
"""
# Public API of this module
__all__ = ['serialize', 'deserialize', 'Serializable']

# Type variable tying the return type of ``deserialize`` to its ``cls`` argument
T = TypeVar('T')
def _get_type_hints(cls: Type):
try:
return get_type_hints(cls)
except:
return get_type_hints(cls, globalns={})
def _get_origin(cls: Type):
return getattr(cls, '__origin__', None)
def _get_args(cls: Type) -> tuple:
return getattr(cls, '__args__', ())
# Classes that (de)serialization passes through without conversion
_BASIC_CLASSES = (
    type(None), bool, int, float, str,  # scalars
    list, dict,  # containers without element type information
    lib_semver_VersionInfo, semver_VersionInfo,  # version objects are kept as-is
)
def serialize(obj) -> Any:
    """
    Convert ``obj`` into a structure built from basic Python types.

    - ``None``, ``int``, ``float``, ``str`` and ``bool`` (exact types) are returned as-is
    - semver ``VersionInfo`` objects are returned as-is
    - lists and tuples become lists of serialized elements
    - dicts keep their keys; their values are serialized
    - Enum members become their ``name``
    - any other object becomes a dict of its public (not ``_``-prefixed) instance
      attributes, serialized recursively

    :param obj: The object to serialize
    :return: The serialized data
    :raises TypeError: If ``obj`` matches none of the above and has no ``__dict__``
    """
    if type(obj) in (type(None), int, float, str, bool):
        return obj
    elif isinstance(obj, (lib_semver_VersionInfo, semver_VersionInfo)):
        return obj
    elif isinstance(obj, (list, tuple)):
        return [serialize(item) for item in obj]
    elif isinstance(obj, dict):
        return {key: serialize(value) for key, value in obj.items()}
    elif isinstance(obj.__class__, EnumMeta):
        return obj.name
    try:
        # vars() raises TypeError for objects without a __dict__ (e.g. __slots__-only classes)
        attr_dict = vars(obj)
    except TypeError:
        raise TypeError(f'Unsupported input type {type(obj)}') from None
    # don't serialize protected fields; build a new dict so the object's own __dict__ is untouched
    public_attrs = {name: value for name, value in attr_dict.items() if not name.startswith('_')}
    return serialize(public_attrs)
def deserialize(data, cls: Type[T], *, error_at_missing=False, error_at_redundancy=False) -> T:
    """
    Build an instance of ``cls`` from ``data`` made of basic types, guided by type hints.

    Target classes are handled in this order:

    - ``None`` is treated as ``NoneType``; ``Any`` returns ``data`` unchanged
    - ``Union[...]`` / ``Optional[...]``: candidates are tried in order, the first success wins
    - a class in ``_BASIC_CLASSES`` equal to ``type(data)``: ``data`` is returned as-is
    - ``float`` also accepts an ``int`` (converted with ``float()``)
    - ``List[X]`` / ``Dict[K, V]``: elements (and keys) are deserialized recursively;
      a bare ``List`` / ``Dict`` without type arguments treats them as ``Any``
    - Enum classes accept a member name (``cls[data]``)
    - any other class accepts a dict: the instance is created with ``cls()``, its public
      annotated fields are filled from the dict (falling back to a shallow copy of the class
      attribute of the same name), and ``on_deserialization()`` is called for ``Serializable``

    :param data: The data to convert
    :param cls: The target class
    :param error_at_missing: Raise ``ValueError`` when a public annotated field is absent from the input dict
    :param error_at_redundancy: Raise ``ValueError`` when the input dict has keys that were not
        consumed as public annotated fields
    :return: The deserialized object
    :raises TypeError: When ``data`` cannot be converted to ``cls`` or ``cls()`` fails
    :raises ValueError: For missing / redundant fields, when the corresponding flag is set
    :raises KeyError: When an Enum target (outside a ``Union``) receives an unknown member name
    """
    def _recurse(sub_data, sub_cls):
        # Deserialize a nested value with the same strictness flags
        return deserialize(sub_data, sub_cls, error_at_missing=error_at_missing, error_at_redundancy=error_at_redundancy)

    # in case None instead of NoneType is passed
    if cls is None:
        cls = type(None)
    # if its type is Any, then simply return the data
    if cls is Any:
        return data
    # Union
    # Unpack Union first since the target class is not confirmed yet
    elif _get_origin(cls) == Union:
        for possible_cls in _get_args(cls):
            try:
                return _recurse(data, possible_cls)
            # KeyError comes from an unknown Enum member name, which also means "this candidate does not match"
            except (TypeError, ValueError, KeyError):
                pass
        raise TypeError('Data in type {} cannot match any candidate of target class {}'.format(type(data), cls))
    # Element (None, int, float, str, list, dict)
    # For list and dict, since it doesn't have any type hint, we choose to simply return the data
    elif cls in _BASIC_CLASSES and type(data) is cls:
        return data
    # float thing
    elif cls is float and isinstance(data, int):
        return float(data)
    # List
    elif _get_origin(cls) == List[int].__origin__ and isinstance(data, list):
        args = _get_args(cls)
        element_type = args[0] if args else Any
        return [_recurse(element, element_type) for element in data]
    # Dict
    elif _get_origin(cls) == Dict[int, int].__origin__ and isinstance(data, dict):
        args = _get_args(cls)
        key_type, val_type = args if len(args) == 2 else (Any, Any)
        instance = {}
        for key, value in data.items():
            deserialized_key = _recurse(key, key_type)
            deserialized_value = _recurse(value, val_type)
            instance[deserialized_key] = deserialized_value
        return instance
    # Enum
    elif isinstance(cls, EnumMeta) and isinstance(data, str):
        return cls[data]
    # Object
    elif cls not in _BASIC_CLASSES and isinstance(cls, type) and isinstance(data, dict):
        try:
            result = cls()
        except Exception as e:
            raise TypeError('Failed to construct instance of class {}'.format(cls)) from e
        input_key_set = set(data.keys())
        for attr_name, attr_type in _get_type_hints(cls).items():
            if attr_name.startswith('_'):
                continue  # protected / private fields are never deserialized
            if attr_name in data:
                setattr(result, attr_name, _recurse(data[attr_name], attr_type))
                input_key_set.remove(attr_name)
            elif error_at_missing:
                raise ValueError('Missing attribute {} for class {} in input object {}'.format(attr_name, cls, data))
            elif hasattr(cls, attr_name):
                # shallow-copy the class-level default so the instance does not share that same object
                setattr(result, attr_name, copy.copy(getattr(cls, attr_name)))
        if error_at_redundancy and input_key_set:
            raise ValueError('Redundancy attributes {} for class {} in input object {}'.format(input_key_set, cls, data))
        if isinstance(result, Serializable):
            result.on_deserialization()
        return result
    else:
        raise TypeError('Unsupported input type: expected class {} but found data with class {}'.format(cls, type(data)))
class Serializable(ABC):
    """
    Base class for objects that are (de)serialized with ``serialize`` / ``deserialize``.

    Fields are declared as class-level type annotations. Only public (not ``_``-prefixed)
    annotated fields are accepted as keyword arguments by ``__init__`` and are filled
    during deserialization.
    """
    # Cached public field annotations. Name-mangled to ``_Serializable__annotations_cache``;
    # each class that calls get_annotations_fields() stores its own copy.
    __annotations_cache: dict = None
    # Guards lazy filling of the cache; one lock shared by all subclasses
    __annotations_lock = Lock()

    def __init__(self, **kwargs):
        """
        Create an instance and set the given public fields directly.

        :raises KeyError: If a keyword is not a public annotated field of the class
        """
        if kwargs:
            fields = self.get_annotations_fields()
            for key in kwargs:
                if key not in fields:
                    raise KeyError('Unknown key received in __init__ of class {}: {}'.format(self.__class__, key))
        vars(self).update(kwargs)

    @classmethod
    def __get_annotation_dict(cls) -> dict:
        """Collect the public (not ``_``-prefixed) entries of the class's resolved type hints."""
        return {
            attr_name: attr_type
            for attr_name, attr_type in _get_type_hints(cls).items()
            if not attr_name.startswith('_')
        }

    @classmethod
    def get_annotations_fields(cls) -> Dict[str, Type]:
        """
        Return the public field annotations of this class, computing them on first use.

        The result is cached per class; the same dict object is returned on later calls.
        """
        with cls.__annotations_lock:
            # Look only in this class's own namespace: a plain ``cls.__annotations_cache`` lookup
            # walks the MRO and would return a cache stored by a parent class, hiding the
            # fields this subclass adds
            cache = vars(cls).get('_Serializable__annotations_cache')
            if cache is None:
                cache = cls.__get_annotation_dict()
                cls.__annotations_cache = cache
            return cache

    def serialize(self) -> dict:
        """Serialize this object with the module-level ``serialize`` function."""
        return serialize(self)

    @classmethod
    def deserialize(cls, data: dict, **kwargs):
        """Deserialize ``data`` into an instance of this class; ``kwargs`` are forwarded to ``deserialize``."""
        return deserialize(data, cls, **kwargs)

    def update_from(self, data: dict):
        """Deserialize ``data`` into a new instance and copy its instance attributes onto ``self``."""
        vars(self).update(vars(self.deserialize(data)))

    @classmethod
    def get_default(cls):
        """Return an instance deserialized from an empty dict, i.e. populated only from class defaults."""
        return cls.deserialize({})

    def on_deserialization(self):
        """
        Invoked after being deserialized
        """
        pass