Skip to content

Commit 37bed92

Browse files
committed
Add Type Hinting for bn128 module
1 parent 6cb9582 commit 37bed92

File tree

5 files changed

+151
-106
lines changed

5 files changed

+151
-106
lines changed

py_ecc/bn128/bn128_curve.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
from __future__ import absolute_import
22

3+
from typing import (
4+
cast,
5+
)
6+
7+
from py_ecc.typing import (
8+
Field,
9+
FQ2Point2D,
10+
FQ12Point2D,
11+
GeneralPoint,
12+
Point2D,
13+
)
14+
315
from .bn128_field_elements import (
416
field_modulus,
517
FQ,
@@ -23,7 +35,7 @@
2335
b12 = FQ12([3] + [0] * 11)
2436

2537
# Generator for curve over FQ
26-
G1 = (FQ(1), FQ(2))
38+
G1 = cast(Point2D[FQ], (FQ(1), FQ(2)))
2739
# Generator for twisted curve over FQ2
2840
G2 = (
2941
FQ2([
@@ -42,33 +54,34 @@
4254

4355

4456
# Check if a point is the point at infinity
45-
def is_inf(pt):
57+
def is_inf(pt: GeneralPoint[Field]) -> bool:
4658
return pt is None
4759

4860

4961
# Check that a point is on the curve defined by y**2 == x**3 + b
50-
def is_on_curve(pt, b):
62+
def is_on_curve(pt: Point2D[Field], b: Field) -> bool:
5163
if is_inf(pt):
5264
return True
5365
x, y = pt
5466
return y**2 - x**3 == b
5567

5668

5769
assert is_on_curve(G1, b)
58-
assert is_on_curve(G2, b2)
70+
assert is_on_curve(cast(Point2D[FQ2], G2), b2)
5971

6072

6173
# Elliptic curve doubling
62-
def double(pt):
74+
def double(pt: Point2D[Field]) -> Point2D[Field]:
6375
x, y = pt
6476
m = 3 * x**2 / (2 * y)
6577
newx = m**2 - 2 * x
6678
newy = -m * newx + m * x - y
67-
return newx, newy
79+
return (newx, newy)
6880

6981

7082
# Elliptic curve addition
71-
def add(p1, p2):
83+
def add(p1: Point2D[Field],
84+
p2: Point2D[Field]) -> Point2D[Field]:
7285
if p1 is None or p2 is None:
7386
return p1 if p2 is None else p2
7487
x1, y1 = p1
@@ -86,7 +99,7 @@ def add(p1, p2):
8699

87100

88101
# Elliptic curve point multiplication
89-
def multiply(pt, n):
102+
def multiply(pt: Point2D[Field], n: int) -> Point2D[Field]:
90103
if n == 0:
91104
return None
92105
elif n == 1:
@@ -97,7 +110,7 @@ def multiply(pt, n):
97110
return add(multiply(double(pt), int(n // 2)), pt)
98111

99112

100-
def eq(p1, p2):
113+
def eq(p1: GeneralPoint[Field], p2: GeneralPoint[Field]) -> bool:
101114
return p1 == p2
102115

103116

@@ -106,14 +119,14 @@ def eq(p1, p2):
106119

107120

108121
# Convert P => -P
109-
def neg(pt):
122+
def neg(pt: Point2D[Field]) -> Point2D[Field]:
110123
if pt is None:
111124
return None
112125
x, y = pt
113126
return (x, -y)
114127

115128

116-
def twist(pt):
129+
def twist(pt: FQ2Point2D) -> FQ12Point2D:
117130
if pt is None:
118131
return None
119132
_x, _y = pt
@@ -122,12 +135,12 @@ def twist(pt):
122135
ycoeffs = [_y.coeffs[0] - _y.coeffs[1] * 9, _y.coeffs[1]]
123136
# Isomorphism into subfield of Z[p] / w**12 - 18 * w**6 + 82,
124137
# where w**6 = x
125-
nx = FQ12([xcoeffs[0]] + [0] * 5 + [xcoeffs[1]] + [0] * 5)
126-
ny = FQ12([ycoeffs[0]] + [0] * 5 + [ycoeffs[1]] + [0] * 5)
138+
nx = FQ12([int(xcoeffs[0])] + [0] * 5 + [int(xcoeffs[1])] + [0] * 5)
139+
ny = FQ12([int(ycoeffs[0])] + [0] * 5 + [int(ycoeffs[1])] + [0] * 5)
127140
# Divide x coord by w**2 and y coord by w**3
128-
return (nx * w ** 2, ny * w**3)
141+
return cast(FQ12Point2D, (nx * w ** 2, ny * w**3))
129142

130143

131-
G12 = twist(G2)
144+
G12 = twist(cast(FQ2Point2D, G2))
132145
# Check that the twist creates a point that is on the curve
133146
assert is_on_curve(G12, b12)

0 commit comments

Comments
 (0)