Skip to content

Commit 8773252

Browse files
Type annotation work in manim/mobject/geometry/ (#3961)
Fixes typehints in manim.mobject.geometry, and enables type checking of those modules. Part of #3375
1 parent a395ffd commit 8773252

14 files changed

+414
-260
lines changed

manim/mobject/geometry/arc.py

Lines changed: 81 additions & 58 deletions
Large diffs are not rendered by default.

manim/mobject/geometry/boolean_ops.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from manim.mobject.types.vectorized_mobject import VMobject
1414

1515
if TYPE_CHECKING:
16-
from manim.typing import Point2D_Array, Point3D_Array
16+
from typing import Any
17+
18+
from manim.typing import InternalPoint3D_Array, Point2D_Array
1719

1820
from ...constants import RendererType
1921

@@ -30,7 +32,7 @@ def _convert_2d_to_3d_array(
3032
self,
3133
points: Point2D_Array,
3234
z_dim: float = 0.0,
33-
) -> Point3D_Array:
35+
) -> InternalPoint3D_Array:
3436
"""Converts an iterable with coordinates in 2D to 3D by adding
3537
:attr:`z_dim` as the Z coordinate.
3638
@@ -51,13 +53,14 @@ def _convert_2d_to_3d_array(
5153
>>> a = _BooleanOps()
5254
>>> p = [(1, 2), (3, 4)]
5355
>>> a._convert_2d_to_3d_array(p)
54-
[array([1., 2., 0.]), array([3., 4., 0.])]
56+
array([[1., 2., 0.],
57+
[3., 4., 0.]])
5558
"""
56-
points = list(points)
57-
for i, point in enumerate(points):
59+
list_of_points = list(points)
60+
for i, point in enumerate(list_of_points):
5861
if len(point) == 2:
59-
points[i] = np.array(list(point) + [z_dim])
60-
return points
62+
list_of_points[i] = np.array(list(point) + [z_dim])
63+
return np.asarray(list_of_points)
6164

6265
def _convert_vmobject_to_skia_path(self, vmobject: VMobject) -> SkiaPath:
6366
"""Converts a :class:`~.VMobject` to SkiaPath. This method only works for
@@ -95,7 +98,7 @@ def _convert_vmobject_to_skia_path(self, vmobject: VMobject) -> SkiaPath:
9598
if vmobject.consider_points_equals(subpath[0], subpath[-1]):
9699
path.close()
97100
elif config.renderer == RendererType.CAIRO:
98-
subpaths = vmobject.gen_subpaths_from_points_2d(points)
101+
subpaths = vmobject.gen_subpaths_from_points_2d(points) # type: ignore[assignment]
99102
for subpath in subpaths:
100103
quads = vmobject.gen_cubic_bezier_tuples_from_points(subpath)
101104
start = subpath[0]
@@ -177,7 +180,7 @@ def construct(self):
177180
178181
"""
179182

180-
def __init__(self, *vmobjects: VMobject, **kwargs) -> None:
183+
def __init__(self, *vmobjects: VMobject, **kwargs: Any) -> None:
181184
if len(vmobjects) < 2:
182185
raise ValueError("At least 2 mobjects needed for Union.")
183186
super().__init__(**kwargs)
@@ -216,7 +219,7 @@ def construct(self):
216219
217220
"""
218221

219-
def __init__(self, subject: VMobject, clip: VMobject, **kwargs) -> None:
222+
def __init__(self, subject: VMobject, clip: VMobject, **kwargs: Any) -> None:
220223
super().__init__(**kwargs)
221224
outpen = SkiaPath()
222225
difference(
@@ -258,7 +261,7 @@ def construct(self):
258261
259262
"""
260263

261-
def __init__(self, *vmobjects: VMobject, **kwargs) -> None:
264+
def __init__(self, *vmobjects: VMobject, **kwargs: Any) -> None:
262265
if len(vmobjects) < 2:
263266
raise ValueError("At least 2 mobjects needed for Intersection.")
264267

@@ -311,7 +314,7 @@ def construct(self):
311314
312315
"""
313316

314-
def __init__(self, subject: VMobject, clip: VMobject, **kwargs) -> None:
317+
def __init__(self, subject: VMobject, clip: VMobject, **kwargs: Any) -> None:
315318
super().__init__(**kwargs)
316319
outpen = SkiaPath()
317320
xor(

manim/mobject/geometry/labeled.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
__all__ = ["LabeledLine", "LabeledArrow"]
66

7+
from typing import TYPE_CHECKING
8+
79
from manim.constants import *
810
from manim.mobject.geometry.line import Arrow, Line
911
from manim.mobject.geometry.shape_matchers import (
@@ -14,6 +16,9 @@
1416
from manim.mobject.text.text_mobject import Text
1517
from manim.utils.color import WHITE, ManimColor, ParsableManimColor
1618

19+
if TYPE_CHECKING:
20+
from typing import Any
21+
1722

1823
class LabeledLine(Line):
1924
"""Constructs a line containing a label box somewhere along its length.
@@ -67,17 +72,19 @@ def __init__(
6772
font_size: float = DEFAULT_FONT_SIZE,
6873
label_color: ParsableManimColor = WHITE,
6974
label_frame: bool = True,
70-
frame_fill_color: ParsableManimColor = None,
75+
frame_fill_color: ParsableManimColor | None = None,
7176
frame_fill_opacity: float = 1,
72-
*args,
73-
**kwargs,
77+
*args: Any,
78+
**kwargs: Any,
7479
) -> None:
7580
label_color = ManimColor(label_color)
7681
frame_fill_color = ManimColor(frame_fill_color)
7782
if isinstance(label, str):
7883
from manim import MathTex
7984

80-
rendered_label = MathTex(label, color=label_color, font_size=font_size)
85+
rendered_label: Tex | MathTex | Text = MathTex(
86+
label, color=label_color, font_size=font_size
87+
)
8188
else:
8289
rendered_label = label
8390

@@ -149,7 +156,7 @@ def construct(self):
149156

150157
def __init__(
151158
self,
152-
*args,
153-
**kwargs,
159+
*args: Any,
160+
**kwargs: Any,
154161
) -> None:
155162
super().__init__(*args, **kwargs)

0 commit comments

Comments
 (0)