use richcmp insted

This commit is contained in:
shenjack 2023-03-28 23:53:36 +08:00
parent 035ad7bf3e
commit f059466e53
2 changed files with 47 additions and 25 deletions

View File

@ -7,6 +7,7 @@
*/
pub mod python_class {
use pyo3::class::basic::CompareOp;
use pyo3::prelude::*;
use crate::math::matrix::{Matrix3, Matrix4};
@ -95,8 +96,15 @@ pub mod python_class {
};
}
fn __eq__(&self, other: &Self) -> bool {
return self.data == other.data;
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
match op {
CompareOp::Lt => Ok(self.data < other.data),
CompareOp::Le => Ok(self.data <= other.data),
CompareOp::Eq => Ok(self.data == other.data),
CompareOp::Ne => Ok(self.data != other.data),
CompareOp::Gt => Ok(self.data > other.data),
CompareOp::Ge => Ok(self.data >= other.data),
}
}
fn __repr__(&self) -> String {
@ -165,8 +173,15 @@ pub mod python_class {
};
}
fn __eq__(&self, other: &Self) -> bool {
return self.data == other.data;
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
match op {
CompareOp::Lt => Ok(self.data < other.data),
CompareOp::Le => Ok(self.data <= other.data),
CompareOp::Eq => Ok(self.data == other.data),
CompareOp::Ne => Ok(self.data != other.data),
CompareOp::Gt => Ok(self.data > other.data),
CompareOp::Ge => Ok(self.data >= other.data),
}
}
fn __repr__(&self) -> String {
@ -248,8 +263,15 @@ pub mod python_class {
};
}
fn __eq__(&self, other: &Self) -> bool {
return self.data == other.data;
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
match op {
CompareOp::Lt => Ok(self.data < other.data),
CompareOp::Le => Ok(self.data <= other.data),
CompareOp::Eq => Ok(self.data == other.data),
CompareOp::Ne => Ok(self.data != other.data),
CompareOp::Gt => Ok(self.data > other.data),
CompareOp::Ge => Ok(self.data >= other.data),
}
}
fn __repr__(&self) -> String {

View File

@ -44,28 +44,28 @@ class TestVector(unittest.TestCase):
vec2, vec3, vec4 = gen_random_vector()
vec2_1, vec3_1, vec4_1 = gen_random_vector()
print(f"{vec2=} {vec2_1=}")
print(f"{vec3=} {vec3_1=}")
print(f"{vec4=} {vec4_1=}")
print('test add')
self.assertEqual(vec2 + vec2_1, Vector2_rs(vec2.x + vec2_1.x, vec2.y + vec2_1.y))
self.assertEqual(vec3 + vec3_1, Vector3_rs(vec3.x + vec3_1.x, vec3.y + vec3_1.y, vec3.z + vec3_1.z))
self.assertEqual(vec4 + vec4_1, Vector4_rs(vec4.x + vec4_1.x, vec4.y + vec4_1.y, vec4.z + vec4_1.z, vec4.w + vec4_1.w))
print(f"{vec2 + vec2_1=}")
print(f"{vec3 + vec3_1=}")
print(f"{vec4 + vec4_1=}")
print('test sub')
self.assertEqual(vec2 - vec2_1, Vector2_rs(vec2.x - vec2_1.x, vec2.y - vec2_1.y))
self.assertEqual(vec3 - vec3_1, Vector3_rs(vec3.x - vec3_1.x, vec3.y - vec3_1.y, vec3.z - vec3_1.z))
self.assertEqual(vec4 - vec4_1, Vector4_rs(vec4.x - vec4_1.x, vec4.y - vec4_1.y, vec4.z - vec4_1.z, vec4.w - vec4_1.w))
print(f"{vec2 - vec2_1=}")
print(f"{vec3 - vec3_1=}")
print(f"{vec4 - vec4_1=}")
print('test mul')
self.assertEqual(vec2 * vec2_1, Vector2_rs(vec2.x * vec2_1.x, vec2.y * vec2_1.y))
self.assertEqual(vec3 * vec3_1, Vector3_rs(vec3.x * vec3_1.x, vec3.y * vec3_1.y, vec3.z * vec3_1.z))
self.assertEqual(vec4 * vec4_1, Vector4_rs(vec4.x * vec4_1.x, vec4.y * vec4_1.y, vec4.z * vec4_1.z, vec4.w * vec4_1.w))
print(f"{vec2 * vec2_1=}")
print(f"{vec3 * vec3_1=}")
print(f"{vec4 * vec4_1=}")
print('test true_div')
self.assertEqual(vec2 / vec2_1, Vector2_rs(vec2.x / vec2_1.x, vec2.y / vec2_1.y))
self.assertEqual(vec3 / vec3_1, Vector3_rs(vec3.x / vec3_1.x, vec3.y / vec3_1.y, vec3.z / vec3_1.z))
self.assertEqual(vec4 / vec4_1, Vector4_rs(vec4.x / vec4_1.x, vec4.y / vec4_1.y, vec4.z / vec4_1.z, vec4.w / vec4_1.w))
print(f"{vec2 / vec2_1=}")
print(f"{vec3 / vec3_1=}")
print(f"{vec4 / vec4_1=}")
print('test floor_div')
self.assertEqual(vec2 // vec2_1, Vector2_rs(vec2.x // vec2_1.x, vec2.y // vec2_1.y))
self.assertEqual(vec3 // vec3_1, Vector3_rs(vec3.x // vec3_1.x, vec3.y // vec3_1.y, vec3.z // vec3_1.z))
self.assertEqual(vec4 // vec4_1, Vector4_rs(vec4.x // vec4_1.x, vec4.y // vec4_1.y, vec4.z // vec4_1.z, vec4.w // vec4_1.w))
print(f"{vec2 // vec2_1=}")
print(f"{vec3 // vec3_1=}")
print(f"{vec4 // vec4_1=}")