diff --git a/libs/pyglet_rs/src/src/pymath.rs b/libs/pyglet_rs/src/src/pymath.rs index 046412f..6bc5a0b 100644 --- a/libs/pyglet_rs/src/src/pymath.rs +++ b/libs/pyglet_rs/src/src/pymath.rs @@ -7,6 +7,7 @@ */ pub mod python_class { + use pyo3::class::basic::CompareOp; use pyo3::prelude::*; use crate::math::matrix::{Matrix3, Matrix4}; @@ -95,8 +96,15 @@ pub mod python_class { }; } - fn __eq__(&self, other: &Self) -> bool { - return self.data == other.data; + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Lt => Ok(self.data < other.data), + CompareOp::Le => Ok(self.data <= other.data), + CompareOp::Eq => Ok(self.data == other.data), + CompareOp::Ne => Ok(self.data != other.data), + CompareOp::Gt => Ok(self.data > other.data), + CompareOp::Ge => Ok(self.data >= other.data), + } } fn __repr__(&self) -> String { @@ -165,8 +173,15 @@ pub mod python_class { }; } - fn __eq__(&self, other: &Self) -> bool { - return self.data == other.data; + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Lt => Ok(self.data < other.data), + CompareOp::Le => Ok(self.data <= other.data), + CompareOp::Eq => Ok(self.data == other.data), + CompareOp::Ne => Ok(self.data != other.data), + CompareOp::Gt => Ok(self.data > other.data), + CompareOp::Ge => Ok(self.data >= other.data), + } } fn __repr__(&self) -> String { @@ -248,8 +263,15 @@ pub mod python_class { }; } - fn __eq__(&self, other: &Self) -> bool { - return self.data == other.data; + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Lt => Ok(self.data < other.data), + CompareOp::Le => Ok(self.data <= other.data), + CompareOp::Eq => Ok(self.data == other.data), + CompareOp::Ne => Ok(self.data != other.data), + CompareOp::Gt => Ok(self.data > other.data), + CompareOp::Ge => Ok(self.data >= other.data), + } } fn __repr__(&self) -> String { diff --git a/libs/pyglet_rs/src/test/vector.py b/libs/pyglet_rs/src/test/vector.py index c5f26cb..58a749b 100644 --- a/libs/pyglet_rs/src/test/vector.py +++ b/libs/pyglet_rs/src/test/vector.py @@ -44,28 +44,28 @@ class TestVector(unittest.TestCase): vec2, vec3, vec4 = gen_random_vector() vec2_1, vec3_1, vec4_1 = gen_random_vector() - print(f"{vec2=} {vec2_1=}") - print(f"{vec3=} {vec3_1=}") - print(f"{vec4=} {vec4_1=}") - + print('test add') self.assertEqual(vec2 + vec2_1, Vector2_rs(vec2.x + vec2_1.x, vec2.y + vec2_1.y)) + self.assertEqual(vec3 + vec3_1, Vector3_rs(vec3.x + vec3_1.x, vec3.y + vec3_1.y, vec3.z + vec3_1.z)) + self.assertEqual(vec4 + vec4_1, Vector4_rs(vec4.x + vec4_1.x, vec4.y + vec4_1.y, vec4.z + vec4_1.z, vec4.w + vec4_1.w)) - print(f"{vec2 + vec2_1=}") - print(f"{vec3 + vec3_1=}") - print(f"{vec4 + vec4_1=}") + print('test sub') + self.assertEqual(vec2 - vec2_1, Vector2_rs(vec2.x - vec2_1.x, vec2.y - vec2_1.y)) + self.assertEqual(vec3 - vec3_1, Vector3_rs(vec3.x - vec3_1.x, vec3.y - vec3_1.y, vec3.z - vec3_1.z)) + self.assertEqual(vec4 - vec4_1, Vector4_rs(vec4.x - vec4_1.x, vec4.y - vec4_1.y, vec4.z - vec4_1.z, vec4.w - vec4_1.w)) - print(f"{vec2 - vec2_1=}") - print(f"{vec3 - vec3_1=}") - print(f"{vec4 - vec4_1=}") + print('test mul') + self.assertEqual(vec2 * vec2_1, Vector2_rs(vec2.x * vec2_1.x, vec2.y * vec2_1.y)) + self.assertEqual(vec3 * vec3_1, Vector3_rs(vec3.x * vec3_1.x, vec3.y * vec3_1.y, vec3.z * vec3_1.z)) + self.assertEqual(vec4 * vec4_1, Vector4_rs(vec4.x * vec4_1.x, vec4.y * vec4_1.y, vec4.z * vec4_1.z, vec4.w * vec4_1.w)) - print(f"{vec2 * vec2_1=}") - print(f"{vec3 * vec3_1=}") - print(f"{vec4 * vec4_1=}") + print('test true_div') + self.assertEqual(vec2 / vec2_1, Vector2_rs(vec2.x / vec2_1.x, vec2.y / vec2_1.y)) + self.assertEqual(vec3 / vec3_1, Vector3_rs(vec3.x / vec3_1.x, vec3.y / vec3_1.y, vec3.z / vec3_1.z)) + self.assertEqual(vec4 / vec4_1, Vector4_rs(vec4.x / vec4_1.x, vec4.y / vec4_1.y, vec4.z / vec4_1.z, vec4.w / vec4_1.w)) - print(f"{vec2 / vec2_1=}") - print(f"{vec3 / vec3_1=}") - print(f"{vec4 / vec4_1=}") + print('test floor_div') + self.assertEqual(vec2 // vec2_1, Vector2_rs(vec2.x // vec2_1.x, vec2.y // vec2_1.y)) + self.assertEqual(vec3 // vec3_1, Vector3_rs(vec3.x // vec3_1.x, vec3.y // vec3_1.y, vec3.z // vec3_1.z)) + self.assertEqual(vec4 // vec4_1, Vector4_rs(vec4.x // vec4_1.x, vec4.y // vec4_1.y, vec4.z // vec4_1.z, vec4.w // vec4_1.w)) - print(f"{vec2 // vec2_1=}") - print(f"{vec3 // vec3_1=}") - print(f"{vec4 // vec4_1=}") \ No newline at end of file