У меня проблема, потому что я хотел бы повторить n раз все строки из массива (X, Y) без использования циклов и получить массив (n * X, Y)
import jax.numpy as jnp
arr = jnp.array([[12, 14, 12, 0, 1],
[0, 14, 12, 0, 1],
[0, 0, 12, 0, 1]])
n = 3
result = jnp.array([[12 14 12 0 1],
[12 14 12 0 1],
[12 14 12 0 1],
[0 14 12 0 1],
[0 14 12 0 1],
[0 14 12 0 1],
[0 0 12 0 1],
[0 0 12 0 1],
[0 0 12 0 1]])
Я не нашел встроенного метода для выполнения этой операции, пробовал с jnp.tile, jnp.repeat.
jnp.repeat
arr_r = jnp.repeat(arr, n, axis=1)
Output:
[[12 12 12 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 0 0 0 12 12 12 0 0 0 1 1 1]]
arr_t = jnp.tile(arr, n)
Output:
[[12 14 12 0 1 12 14 12 0 1 12 14 12 0 1]
[ 0 14 12 0 1 0 14 12 0 1 0 14 12 0 1]
[ 0 0 12 0 1 0 0 12 0 1 0 0 12 0 1]]
Может быть, я могу построить массив результатов из array_t...
Вы говорите, что пытались jnp.repeat
, но не объясняете, почему это не делает то, что вы хотите. Я предполагаю, что вы пренебрегаете параметром axis
:
jnp.repeat(arr, n, axis=0)
Я пробовал с этим, но я не получаю такого результата.
Что ты получил? Я не могу установить Jax для проверки, но он должен работать точно так же, как обычный numpy.
Я собираюсь отредактировать вопрос и добавить результаты, которые я действительно пробовал.
Вам нужно использовать axis=0
Да, ты прав. Моя ошибка, я думал, что это по умолчанию, но сглаживание по умолчанию...