2019年4月8日月曜日

開発環境

Programming Bitcoin: Learn How to Program Bitcoin from Scratch (Jimmy Song(著)、O'Reilly Media)のChapter 3(Elliptic Curve Cryptography)、Scalar Multiplication for Elliptic Curves、Exercises 4(49)の解答を求めてみる。

コード

Python 3

ecc_test.py

#!/usr/bin/env python3
from unittest import TestCase, main
from ecc import FieldElement, Point


class PointTest(TestCase):
    """Tests for Point equality and the elliptic-curve addition law."""

    # Modulus of the toy curve y^2 = x^3 + 7 used by the finite-field tests.
    PRIME = 223

    def _point223(self, x, y):
        """Return the point (x, y) on y^2 = x^3 + 7 over F_223 from raw ints."""
        prime = self.PRIME
        return Point(FieldElement(x, prime),
                     FieldElement(y, prime),
                     FieldElement(0, prime),
                     FieldElement(7, prime))

    def test_ne1(self):
        p1 = Point(0, 0, 0, 0)
        p2 = Point(1, 1, 0, 0)
        self.assertNotEqual(p1, p2)

    def test_ne2(self):
        p1 = Point(0, 0, 0, 0)
        p2 = Point(1, -1, 0, 0)
        self.assertNotEqual(p1, p2)

    def test_ne_none(self):
        self.assertNotEqual(Point(0, 0, 0, 0), None)
        self.assertNotEqual(None, Point(0, 0, 0, 0))

    def test_add_identity(self):
        # The point at infinity (x = y = None) is the additive identity.
        p1 = Point(0, 0, 0, 0)
        p2 = Point(1, -1, 0, 0)
        inf = Point(None, None, 0, 0)
        self.assertEqual(p1, p1 + inf)
        self.assertEqual(p1, inf + p1)
        self.assertEqual(p2, p2 + inf)
        self.assertEqual(p2, inf + p2)

    def test_add_inverses(self):
        # Same x, opposite y: the vertical line meets infinity.
        p1 = Point(1, 1, 0, 0)
        p2 = Point(1, -1, 0, 0)
        inf = Point(None, None, 0, 0)
        self.assertEqual(inf, p1 + p2)

    def test_add_when_x_not_equal_to(self):
        p1 = Point(2, 5, 5, 7)
        p2 = Point(-1, -1, 5, 7)
        self.assertEqual(Point(3, -7, 5, 7), p1 + p2)

    def test_add_p1_eq_p2(self):
        p = Point(-1, -1, 5, 7)
        self.assertEqual(Point(18, 77, 5, 7), p + p)

    def test_add_vertical_tangent(self):
        # y == 0 makes the tangent vertical, so doubling yields infinity.
        p = Point(-1, 0, 5, 6)
        self.assertEqual(Point(None, None, 5, 6), p + p)

    def test_add_different_curves(self):
        with self.assertRaises(TypeError):
            Point(0, 0, 0, 0) + Point(-1, -1, 5, 7)

    def test_add_finite_field1(self):
        expected = self._point223(220, 181)
        self.assertEqual(expected,
                         self._point223(170, 142) + self._point223(60, 139))

    def test_add_finite_field2(self):
        expected = self._point223(215, 68)
        self.assertEqual(expected,
                         self._point223(47, 71) + self._point223(17, 56))

    def test_add_finite_field3(self):
        expected = self._point223(47, 71)
        self.assertEqual(expected,
                         self._point223(143, 98) + self._point223(76, 66))


class FieldElementTest(TestCase):
    """Tests for FieldElement arithmetic modulo a prime."""

    def setUp(self):
        # a and b share the field F_13; c lives in F_17 to provoke
        # cross-field TypeErrors.
        self.a = FieldElement(6, 13)
        self.b = FieldElement(7, 13)
        self.c = FieldElement(6, 17)

    def test_ne(self):
        self.assertNotEqual(self.a, None)
        self.assertNotEqual(self.a, self.b)
        self.assertNotEqual(self.a, self.c)

    def test_neg(self):
        self.assertEqual(-self.a, FieldElement(7, 13))

    def test_add(self):
        self.assertEqual(self.a + self.b, FieldElement(0, 13))

    def test_add_exc(self):
        with self.assertRaises(TypeError):
            self.a + self.c

    def test_sub(self):
        self.assertEqual(self.a - self.a, FieldElement(0, 13))
        self.assertEqual(self.a - self.b, FieldElement(12, 13))
        self.assertEqual(self.b - self.a, FieldElement(1, 13))

    def test_mul(self):
        self.assertEqual(FieldElement(3, 13), self.a * self.b)

    def test_mul_exc(self):
        # Exercise __mul__'s cross-field check (previously this used `+`).
        with self.assertRaises(TypeError):
            self.a * self.c

    def test_rmul(self):
        expected = FieldElement(2 * self.a.num % self.a.prime, self.a.prime)
        self.assertEqual(expected, 2 * self.a)

    def test_true_div1(self):
        prime = 31
        actual = FieldElement(3, prime) / FieldElement(24, prime)
        self.assertEqual(FieldElement(4, prime), actual)

    def test_true_div2(self):
        prime = 31
        actual = FieldElement(1, prime) / FieldElement(17, prime) ** 3
        self.assertEqual(FieldElement(29, prime), actual)

    def test_true_div3(self):
        prime = 31
        # Precedence: (1 / 4**4) * 11.
        actual = (
            FieldElement(1, prime) /
            FieldElement(4, prime) ** 4 *
            FieldElement(11, prime)
        )
        self.assertEqual(FieldElement(13, prime), actual)


if __name__ == '__main__':
    # Run this module's TestCase classes via unittest when executed directly.
    main()

ecc.py

#!/usr/bin/env python3
class Point:
    """A point on the curve y**2 == x**3 + a*x + b.

    Coordinates and coefficients may be plain numbers or FieldElement
    instances. The point at infinity (the group identity) is encoded as
    x = y = None.
    """

    def __init__(self, x, y, a, b):
        """Create the point (x, y) on the curve with coefficients a, b.

        Raises:
            ValueError: if (x, y) does not satisfy the curve equation.
        """
        self.a = a
        self.b = b
        self.x = x
        self.y = y
        # The point at infinity has no coordinates to validate.
        if self.x is None and self.y is None:
            return
        if y ** 2 != x ** 3 + a * x + b:
            raise ValueError(f'({x}, {y}) is not on the curve')

    def __eq__(self, other):
        """Return True when both points have equal coordinates and curve."""
        # For foreign operands (None, ints, ...) defer to Python's fallback
        # instead of raising AttributeError on `other.x`.
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x == other.x and
                self.y == other.y and
                self.a == other.a and
                self.b == other.b)

    def __ne__(self, other):
        return not (self == other)

    def __add__(self, other):
        """Add two points on the same curve (chord-and-tangent rule).

        Raises:
            TypeError: if the points lie on different curves.
        """
        if self.a != other.a or self.b != other.b:
            raise TypeError(f'Points {self}, {other} are not on the same curve')
        # Identity element: infinity + P == P + infinity == P.
        if self.x is None:
            return other
        if other.x is None:
            return self
        # Same x, different y: P + (-P) is the point at infinity.
        if self.x == other.x and self.y != other.y:
            return self.__class__(None, None, self.a, self.b)
        # Distinct x: use the slope of the secant through both points.
        if self.x != other.x:
            s = (other.y - self.y) / (other.x - self.x)
            x = s ** 2 - self.x - other.x
            y = s * (self.x - x) - self.y
            return self.__class__(x, y, self.a, self.b)
        # Doubling with y == 0: the tangent is vertical. `0 * self.x` builds
        # a zero of the same kind as the coordinates (int or FieldElement).
        if self == other and self.y == 0 * self.x:
            return self.__class__(None, None, self.a, self.b)
        # Doubling: use the slope of the tangent line at this point.
        s = (3 * self.x ** 2 + self.a) / (2 * self.y)
        x = s ** 2 - 2 * self.x
        y = s * (self.x - x) - self.y
        return self.__class__(x, y, self.a, self.b)

    def __repr__(self):
        if self.x is None:
            return 'Point(infinity)'
        if isinstance(self.x, FieldElement):
            return f'Point({self.x.num},{self.y.num})' + \
                f'_{self.a.num}_{self.b.num} ' + \
                f'FieldElement({self.x.prime})'
        return f'Point({self.x},{self.y})_{self.a}_{self.b}'


class FieldElement:
    """An element of the finite field of integers modulo `prime`."""

    def __init__(self, num: int, prime: int):
        """Create the element `num` of F_prime.

        Raises:
            ValueError: if num is outside 0 .. prime - 1.
        """
        if num < 0 or prime <= num:
            raise ValueError(f'Num {num} not in field range 0 to {prime - 1}')
        self.num = num
        self.prime = prime

    def __repr__(self) -> str:
        return f'FieldElement_{self.prime}({self.num})'

    def __eq__(self, other) -> bool:
        # Foreign operands (None, ints, ...) defer to Python's fallback
        # instead of raising AttributeError on `other.num`.
        if not isinstance(other, FieldElement):
            return NotImplemented
        return self.num == other.num and self.prime == other.prime

    def __ne__(self, other) -> bool:
        if other is None:
            return True
        return not self == other

    def __neg__(self):
        return self.__class__(-self.num % self.prime, self.prime)

    def __add__(self, other):
        if self.prime != other.prime:
            raise TypeError('Cannot add two numbers in different Fields')
        return self.__class__((self.num + other.num) % self.prime, self.prime)

    def __sub__(self, other):
        if self.prime != other.prime:
            raise TypeError('Cannot subtract two numbers in different Fields')
        return self + (- other)

    def __mul__(self, other):
        if self.prime != other.prime:
            raise TypeError('Cannot multiply two numbers in different Fields')
        return self.__class__((self.num * other.num) % self.prime, self.prime)

    def __rmul__(self, other):
        # Integer scalar on the left, e.g. `2 * element`.
        return self.__class__(other * self.num % self.prime, self.prime)

    def __pow__(self, exponent):
        """Raise to an integer power; negative exponents invert.

        For a nonzero element a**(p-1) == 1 (Fermat), so the exponent can be
        reduced modulo p - 1. That reduction is wrong for zero
        (0**(p-1) is 0, not 1), so zero is handled separately.

        Raises:
            ZeroDivisionError: for a negative power of zero.
        """
        if self.num == 0:
            if exponent < 0:
                raise ZeroDivisionError('0 cannot be raised to a negative power')
            return self.__class__(0 if exponent else 1, self.prime)
        n = exponent % (self.prime - 1)
        return self.__class__(pow(self.num, n, self.prime), self.prime)

    def __truediv__(self, other):
        """Divide by multiplying with the inverse other**(p-2) (Fermat).

        Raises:
            TypeError: if the operands belong to different fields.
            ZeroDivisionError: if other is zero (it has no inverse).
        """
        if self.prime != other.prime:
            raise TypeError('Cannot divide two numbers in different Fields')
        if other.num == 0:
            raise ZeroDivisionError('division by zero in a finite field')
        num = (self.num *
               pow(other.num, other.prime - 2, other.prime) %
               self.prime)
        prime = self.prime
        return self.__class__(num, prime)
#!/usr/bin/env python3
from ecc import Point, FieldElement


prime = 223
a = FieldElement(0, prime)
b = FieldElement(7, prime)

# Exercise 4 inputs: (scalar, (x, y)) with (x, y) on y^2 = x^3 + 7 over F_223.
t = [(2, (192, 105)),
     (2, (143, 98)),
     (2, (47, 71)),
     (4, (47, 71)),
     (8, (47, 71)),
     (21, (47, 71))]

for scalar, coords in t:
    # Lift the raw integer coordinates into the field and build the point.
    point = Point(*(FieldElement(c, prime) for c in coords), a, b)
    # Only positive scalars are evaluated.
    if scalar <= 0:
        continue
    print(f'{scalar}・{point} = ', end='')
    # scalar * P as repeated addition: P + P + ... + P (scalar terms).
    print(sum([point] * (scalar - 1), point))

入出力結果(cmd(コマンドプロンプト)、Terminal、Jupyter(IPython))

C:\Users\...>py ecc_test.py -v
test_mul (__main__.FieldElementTest) ... ok
test_mul_exc (__main__.FieldElementTest) ... ok
test_ne (__main__.FieldElementTest) ... ok
test_neg (__main__.FieldElementTest) ... ok
test_rmul (__main__.FieldElementTest) ... ok
test_sub (__main__.FieldElementTest) ... ok
test_true_div1 (__main__.FieldElementTest) ... ok
test_true_div2 (__main__.FieldElementTest) ... ok
test_true_div3 (__main__.FieldElementTest) ... ok
test_add_finite_field1 (__main__.PointTest) ... ok
test_add_finite_field2 (__main__.PointTest) ... ok
test_add_finite_field3 (__main__.PointTest) ... ok
test_add_identity (__main__.PointTest) ... ok
test_add_inverses (__main__.PointTest) ... ok
test_add_p1_eq_p2 (__main__.PointTest) ... ok
test_add_when_x_not_equal_to (__main__.PointTest) ... ok
test_ne1 (__main__.PointTest) ... ok
test_ne2 (__main__.PointTest) ... ok
test_ne_none (__main__.PointTest) ... ok

----------------------------------------------------------------------
Ran 19 tests in 0.001s

OK

C:\Users\...>py sample4.py
2・Point(192,105)_0_7 FieldElement(223) = Point(49,71)_0_7 FieldElement(223)
2・Point(143,98)_0_7 FieldElement(223) = Point(64,168)_0_7 FieldElement(223)
2・Point(47,71)_0_7 FieldElement(223) = Point(36,111)_0_7 FieldElement(223)
4・Point(47,71)_0_7 FieldElement(223) = Point(194,51)_0_7 FieldElement(223)
8・Point(47,71)_0_7 FieldElement(223) = Point(116,55)_0_7 FieldElement(223)
21・Point(47,71)_0_7 FieldElement(223) = Point(infinity)

C:\Users\...>

0 コメント:

コメントを投稿