2019年1月9日水曜日

開発環境

問題解決のPythonプログラミング ―数学パズルで鍛えるアルゴリズム的思考 (Srini Devadas (著)、黒川 利明 (翻訳)、オライリージャパン)の18章(メモリは役に立つ)、練習問題(問題3)の解答を求めてみる。

コード

Python 3

#!/usr/bin/env python3
def get_courses_later(course: list, courses: list):
    '''
    Return the courses that can be taken after *course*.

    Each course is a ``[start, finish, weight]`` list. A candidate is kept
    when its start time is at or after the finish time of *course*; the
    original order of *courses* is preserved and the returned list holds the
    same course objects (no copies).

    >>> get_courses_later([9, 12, 3], courses)
    [[12, 13, 1], [15, 16, 1], [16, 17, 1], [18, 20, 2], [17, 19, 2], [13, 20, 7]]
    >>> get_courses_later([17, 19, 2], courses)
    []
    '''
    _, end_time, _ = course

    def starts_after_end(candidate):
        # Touching intervals (start == end_time) count as compatible.
        return candidate[0] >= end_time

    return list(filter(starts_after_end, courses))


def get_weight(courses: list) -> int:
    '''
    Return the total weight of *courses*.

    Each course is a ``[start, finish, weight]`` list; only the third field
    is summed. An empty list has weight 0.

    >>> get_weight(courses)
    19
    '''
    return sum(weight for _start, _finish, weight in courses)


def recursive_select_memoized(courses: list, memo: dict) -> list:
    '''
    Select a maximum-weight set of compatible courses, with memoization.

    For each course, the best selection among the courses that start after
    it finishes is computed recursively and cached in *memo*. The cache key
    is ``repr()`` of that list of later courses. *memo* is updated in place.
    The selection with the strictly greatest total weight wins, so ties go
    to the earliest course in input order.

    Returns ``[]`` when *courses* is empty, or when no course yields a
    positive total weight.

    NOTE(review): a zero-length course (start == finish) is included in its
    own "later" list, which makes the recursion unbounded.

    >>> recursive_select_memoized(courses, {})
    [[9, 12, 3], [12, 13, 1], [13, 20, 7]]
    '''
    if not courses:
        return []
    best_weight = 0
    # Initialized so an all-non-positive-weight input returns [] instead of
    # raising UnboundLocalError.
    best_selection = []
    for course in courses:
        courses_later = get_courses_later(course, courses)
        key = repr(courses_later)
        if key not in memo:
            memo[key] = recursive_select_memoized(courses_later, memo)
        candidate = [course] + memo[key]
        candidate_weight = get_weight(candidate)
        if candidate_weight > best_weight:
            best_weight = candidate_weight
            best_selection = candidate
    return best_selection

# メモ化していない版


def recursive_select(courses: list) -> list:
    '''
    Select a maximum-weight set of compatible courses without memoization.

    For each course, the best selection among the courses that start after
    it finishes is computed recursively. Identical sub-lists may be solved
    many times, since nothing is cached. The selection with the strictly
    greatest total weight wins, so ties go to the earliest course in input
    order.

    Returns ``[]`` when *courses* is empty, or when no course yields a
    positive total weight.

    NOTE(review): a zero-length course (start == finish) is included in its
    own "later" list, which makes the recursion unbounded.

    >>> recursive_select(courses)
    [[9, 12, 3], [12, 13, 1], [13, 20, 7]]
    '''
    if not courses:
        return []
    best_weight = 0
    # Initialized so an all-non-positive-weight input returns [] instead of
    # raising UnboundLocalError.
    best_selection = []
    for course in courses:
        courses_later = get_courses_later(course, courses)
        candidate = [course] + recursive_select(courses_later)
        candidate_weight = get_weight(candidate)
        if candidate_weight > best_weight:
            best_weight = candidate_weight
            best_selection = candidate
    return best_selection


if __name__ == '__main__':
    import doctest
    import time

    courses = [[8, 10, 1], [9, 12, 3], [11, 12, 1], [12, 13, 1], [15, 16, 1],
               [16, 17, 1], [18, 20, 2], [17, 19, 2], [13, 20, 7]]
    # The doctests refer to ``courses``; hand it to doctest explicitly
    # instead of mutating globals().
    doctest.testmod(extraglobs={'courses': courses})

    n = 10000

    # perf_counter() is monotonic and high-resolution, so it is the right
    # clock for elapsed-time measurements (time.time() can jump if the
    # system clock is adjusted).
    start = time.perf_counter()
    for _ in range(n):
        recursive_select(courses)
    t = time.perf_counter() - start
    print(f'メモ化無し: {t}s')

    start = time.perf_counter()
    for _ in range(n):
        recursive_select_memoized(courses, {})
    t = time.perf_counter() - start
    print(f'メモ化有り: {t}s')

入出力結果(Terminal, cmd(コマンドプロンプト), Jupyter(IPython))

$ ./sample3.py -v
Trying:
    get_courses_later([9, 12, 3], courses)
Expecting:
    [[12, 13, 1], [15, 16, 1], [16, 17, 1], [18, 20, 2], [17, 19, 2], [13, 20, 7]]
ok
Trying:
    get_courses_later([17, 19, 2], courses)
Expecting:
    []
ok
Trying:
    get_weight(courses)
Expecting:
    19
ok
Trying:
    recursive_select(courses)
Expecting:
    [[9, 12, 3], [12, 13, 1], [13, 20, 7]]
ok
Trying:
    recursive_select_memoized(courses, {})
Expecting:
    [[9, 12, 3], [12, 13, 1], [13, 20, 7]]
ok
1 items had no tests:
    __main__
4 items passed all tests:
   2 tests in __main__.get_courses_later
   1 tests in __main__.get_weight
   1 tests in __main__.recursive_select
   1 tests in __main__.recursive_select_memoized
5 tests in 5 items.
5 passed and 0 failed.
Test passed.
メモ化無し: 1.786876916885376s
メモ化有り: 0.8609530925750732s
$

0 コメント:

コメントを投稿