Как оптимизировать бесконечный итератор спирали улама?

Я создал бесконечный итератор, который сопоставляет натуральные числа со всеми точками решетки по спирали, похожей на спираль Улама:

Как оптимизировать бесконечный итератор спирали улама?

Код ниже, я сделал его максимально быстрым и не использовал ни одного условия if:

from itertools import islice, repeat

def ulamish_spiral_gen():
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))

def ulamish_spiral(n):
    return list(islice(ulamish_spiral_gen(), n))

Я хочу знать, как запомнить вывод бесконечного итератора, чтобы list(islice(ulamish_spiral_gen(), n)) вызывался только тогда, когда значение n больше, чем последнее n.

Что-то вроде этого:

COMPUTED = []

def ulamish_spiral(n):
    global COMPUTED
    if n > len(COMPUTED):
        COMPUTED = list(islice(ulamish_spiral_gen(), n))
    return COMPUTED[:n]

Это очень просто, но первые len(COMPUTED) члены уже вычислены, нужно вычислить только члены в range(len(COMPUTED), n), но вызов вычисляет все уже вычисленные члены. Поэтому я попытался повторно использовать один и тот же объект генератора и запросить только следующие n - len(COMPUTED) элементы, и мне это удалось.

Но это на самом деле делает код медленнее:

COMPUTED = []
ULAMISH_GEN = ulamish_spiral_gen()
def ulamish_spiral(n):
    if n > (l := len(COMPUTED)):
        COMPUTED.extend(islice(ULAMISH_GEN, n - l))
    return COMPUTED[:n]
In [225]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(8192)
928 µs ± 8.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [226]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(4096); ulamish_spiral(8192)
993 µs ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [227]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(4096); ulamish_spiral(8192); ulamish_spiral(16384)
2.14 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [228]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(16384)
2 ms ± 88.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [229]: COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(16384) == list(islice(ulamish_spiral_gen(), 16384))
Out[229]: True

Как я могу пропустить уже вычисленные термины и сделать код быстрее?

из-за ваших измерений это не кажется значительно медленнее: вы сравниваете один вызов с несколькими вызовами, которые эквивалентны в параметре n.

RomanPerekhrest 24.07.2023 17:25

сделал это как можно быстрее" - я уверен, что мог бы сделать это быстрее :-)

Kelly Bundy 24.07.2023 18:09

@KellyBundy Если вы думаете, что можете сделать это быстрее, рассмотрите возможность публикации ответа.

Ξένη Γήινος 24.07.2023 18:12
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
3
67
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Вы уже рассмотрели "memoize"/"skip". Вот более быстрые бесконечные итераторы. Время для первых 16384 координат (как в вашем тесте):

  1.01 ± 0.00 ms  gen_Kelly_4
  1.03 ± 0.00 ms  gen_Kelly_3
  1.06 ± 0.00 ms  gen_Kelly_2
  1.20 ± 0.00 ms  gen_Kelly_1
  1.57 ± 0.00 ms  gen_original

Python: 3.11.4 (main, Jun 24 2023, 10:18:04) [GCC 13.1.1 20230429]

Вы объединяете все итераторы zip со своим собственным генератором. Мой gen_Kelly_1 вместо этого использует chain.from_iterable для этого.

Вы используете функцию range для создания объектов int снова и снова. Мой gen_Kelly_2 вместо этого сохраняет их в списке и повторно использует.

Мой gen_Kelly_3 далее повторно использует итераторы repeat, а gen_Kelly_4 переворачивает мой список вместо использования обратных итераторов.

Полный код (Попробуйте онлайн!):

from timeit import timeit
from statistics import mean, stdev
from itertools import islice, repeat, cycle, chain, count
import sys

def gen_original():
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))


def gen_Kelly_1():
    def parts():
        xc = yc = length = 0
        yield (0, 0),
        while True:
            length += 1
            yield zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
            yield zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
            length += 1
            yield zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
            yield zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))
    return chain.from_iterable(parts())


def gen_Kelly_2():
    def parts():
        i = 0
        range = []
        while True:
            yield zip(repeat(-i), reversed(range))
            range.insert(0, -i)
            yield zip(range, repeat(-i))
            i += 1
            yield zip(repeat(i), range)
            range.append(i)
            yield zip(reversed(range), repeat(i))
    return chain.from_iterable(parts())


def gen_Kelly_3():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(-i)
            yield zip(rep, reversed(range))
            range.insert(0, -i)
            yield zip(range, rep)
            i += 1
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            yield zip(reversed(range), rep)
    return chain.from_iterable(parts())


def gen_Kelly_4():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(-i)
            yield zip(rep, range)
            range.append(-i)
            range.reverse()
            yield zip(range, rep)
            i += 1
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            range.reverse()
            yield zip(range, rep)
    return chain.from_iterable(parts())


funcs = gen_original, gen_Kelly_1, gen_Kelly_2, gen_Kelly_3, gen_Kelly_4

n = 16384

# Correctness
expect = list(islice(funcs[0](), n))
for f in funcs[1:]:
    result = list(islice(f(), n))
    assert result == expect

# Speed
times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e3 for t in sorted(times[f])[:10]]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '
for _ in range(1000):
    for f in funcs:
        t = timeit(lambda: list(islice(f(), n)), number=1)
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)
print('\nPython:', sys.version)

Я обобщил спираль и сделал вариант с двумя спиралями, каждая из которых пересекает половину решетки и не пересекается друг с другом, и вариант с четырьмя спиралями, каждая из которых пересекает четверть решетки и не пересекается друг с другом. Я разместил свой код на Code Review. Пожалуйста, зайдите туда и посмотрите его.

Ξένη Γήινος 24.07.2023 22:20

Другие вопросы по теме