Pytorch (numpy) расчет ближайших пикселей к точкам

Я пытаюсь решить сложную проблему.

Например, у меня есть пакет предсказанных 2D-изображений (вывод softmax, значение от 0 до 1) с размером: Batch x H x W и наземной достоверностью Batch x H x W.

Pytorch (numpy) расчет ближайших пикселей к точкам

Пиксели светло-серого цвета - это фон со значением 0, а пиксели темно-серого цвета - это передний план со значением 1. Я пытаюсь вычислить координаты центра масс с помощью scipy.ndimage.center_of_mass на каждом наземном изображении. Затем я получаю центральную точку расположения C (красный цвет) для каждой наземной истины. Набор точек C - Batch x 1.

Теперь для каждого пикселя A (желтый цвет) в предсказанных изображениях я хочу получить три пикселя B1, B2, B3 (синий цвет), которые являются ближайшими к A на линии AC (здесь C - соответствующее расположение центра масс на земле).

Я использовал следующий код, чтобы получить три ближайших точки B1, B2, B3.

def connect(ends, m=3):
    d0, d1 = np.abs(np.diff(ends, axis=0))[0]
    if d0 > d1:
        return np.c_[np.linspace(ends[0, 0], ends[1, 0], m + 1, dtype=np.int32),
                 np.round(np.linspace(ends[0, 1], ends[1, 1], m + 1))
                     .astype(np.int32)]
    else:
        return np.c_[np.round(np.linspace(ends[0, 0], ends[1, 0], m + 1))
                     .astype(np.int32),
                 np.linspace(ends[0, 1], ends[1, 1], m + 1, dtype=np.int32)]

Итак, набор точек B - Batch x 3 x H x W.

Затем я хочу вычислить вот так: |Value(A)-Value(B1)|+|Value(A)-Value(B2)|+|Value(A)-Value(B3)|. Размер результата должен быть Batch x H x W.

Есть ли какие-нибудь уловки векторизации, которые можно использовать для обновления значения каждого пикселя в предсказанных изображениях? Или это можно решить с помощью функций pytorch? Мне нужно найти способ обновить все изображение. Прогнозируемое изображение - это выходной сигнал softmax. Я не могу использовать цикл for для вычисления каждого отдельного значения, поскольку оно станет недифференцируемым. Большое спасибо.

Проверьте это: stackoverflow.com/a/47704298/3540982

Matin 16.08.2018 11:42

Возможно, вы захотите попробовать создать минимальный воспроизводимый пример, у вас слишком много независимых частей в вашем вопросе, чтобы дать хороший ответ.

Daniel F 16.08.2018 12:03

@DanielF Привет, я обновил вопрос. Теперь нужно только обновить значение пикселя A. Можете ли вы помочь решить эту проблему? Спасибо.

N.Z 16.08.2018 13:12

@Matin Спасибо, Матин. Я могу вычислить точки B, но не знаю, как вычислить | Значение (A) -Значение (B1) | + | Значение (A) -Значение (B2) | + | Значение (A) -Значение (B3‌) | пока не используется цикл for.

N.Z 16.08.2018 13:39
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
4
827
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Как предлагает @Matin, вы можете рассмотреть Алгоритм Брезенхема, чтобы получить свои очки на линии AC.

Упрощенная реализация PyTorch может быть следующей (непосредственно адаптированной из псевдокода здесь; может быть оптимизирована):

import torch

def get_points_from_low(x0, y0, x1, y1, num_points=3):
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dy = dy * yi
    D = 2 * dy - dx

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        x = x + xi
        is_D_gt_0 = (D > 0).long()
        y = y + is_D_gt_0 * yi
        D = D + 2 * dy - is_D_gt_0 * 2 * dx

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from_high(x0, y0, x1, y1, num_points=3):
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dx = dx * xi
    D = 2 * dx - dy

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        y = y + yi
        is_D_gt_0 = (D > 0).long()
        x = x + is_D_gt_0 * xi
        D = D + 2 * dx - is_D_gt_0 * 2 * dy

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from(x0, y0, x1, y1, num_points=3):
    is_dy_lt_dx = (torch.abs(y1 - y0) < torch.abs(x1 - x0)).long()
    is_x0_gt_x1 = (x0 > x1).long()
    is_y0_gt_y1 = (y0 > y1).long()

    sign = 1 - 2 * is_x0_gt_x1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_low = get_points_from_low(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points)
    points_low *= sign.view(-1, 1, 1).expand_as(points_low)

    sign = 1 - 2 * is_y0_gt_y1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_high = get_points_from_high(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points) * sign
    points_high *= sign.view(-1, 1, 1).expand_as(points_high)

    is_dy_lt_dx = is_dy_lt_dx.view(-1, 1, 1).expand(-1, num_points, 2)
    points = points_low * is_dy_lt_dx + points_high * (1 - is_dy_lt_dx)

    return points

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
num_points = 3

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)
print(Bs)
# tensor([[[1, 1],
#          [2, 2],
#          [3, 2]],
#         [[7, 6],
#          [6, 5],
#          [5, 5]]])

Когда у вас есть свои очки, вы можете получить их «значения» (Value(A), Value(B1) и т.д.), используя torch.index_select() (обратите внимание, что на данный момент этот метод принимает только одномерные индексы, поэтому вам нужно распутать ваши данные). Все вместе, это будет выглядеть примерно так (расширение A от формы (Batch, 2) до (Batch, H, W, 2) осталось для упражнений ...)

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
batch_size = A.shape[0]
num_points = 3
map_size = (9, 9)
map_num_elements = map_size[0] * map_size[1]
map_values = torch.stack((torch.arange(0, map_num_elements).view(*map_size),
                          torch.arange(0, -map_num_elements, -1).view(*map_size)))

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)

# Get map values in positions A:
A_unravel = torch.arange(0, batch_size) * map_num_elements
A_unravel = A_unravel + A[:, 0] * map_size[1] + A[:, 1]
values_A = torch.index_select(map_values.view(-1), dim=0, index=A_unravel)
print(values_A)
# tensor([ 1, -4])

# Get map values in positions A:
A_unravel = torch.arange(0, batch_size) * map_num_elements
A_unravel = A_unravel + A[:, 0] * map_size[1] + A[:, 1]
values_A = torch.index_select(map_values.view(-1), dim=0, index=A_unravel)
print(values_A)
# tensor([  1, -78])

# Get map values in positions B:
Bs_flatten = Bs.view(-1, 2)
Bs_unravel = (torch.arange(0, batch_size)
              .unsqueeze(1)
              .repeat(1, num_points)
              .view(num_points * batch_size) * map_num_elements)
Bs_unravel = Bs_unravel + Bs_flatten[:, 0] * map_size[1] + Bs_flatten[:, 1]
values_B = torch.index_select(map_values.view(-1), dim=0, index=Bs_unravel)
values_B = values_B.view(batch_size, num_points)
print(values_B)
# tensor([[ 10,  20,  29],
#         [-69, -59, -50]])

# Compute result:
res = torch.abs(values_A.unsqueeze(-1).expand_as(values_B) - values_B)
print(res)
# tensor([[ 9, 19, 28],
#         [ 9, 19, 28]])
res = torch.sum(res, dim=1)
print(res)
# tensor([56, 56])

Привет, Алдрим, спасибо за ответ. Когда я получаю координаты трех точек для каждого пикселя A, знаете ли вы, как я могу вычислить | Value (A) -Value (B1) | + | Value (A) -Value (B2) | + | Value (A) - Значение (B3‌) |? Значения каждого пикселя A и соответствующих ему трех пикселей B поступают из одного и того же вывода softmax. И вывод softmax - это переменная в Pytorch с require_grad = True, поэтому мне нужно вычислить это со всем изображением.

N.Z 16.08.2018 13:30

Я обновил свой ответ, чтобы дать дальнейшие указания, хотя ваша проблема, возможно, слишком обширна, чтобы ее можно было исследовать в одном сообщении ...

benjaminplanche 16.08.2018 14:52

Нашел ошибку. Если координаты A больше, чем координаты C, возвращаемые результаты должны быть обратными. Например, A - это [8,6], C - это [2, 3]. Вы можете помочь исправить это? Спасибо.

N.Z 17.08.2018 16:17

Действительно, алгоритм Брезенхема в таких случаях переключает A и C, не обращая внимания на порядок / направление точек. Я исправил предложенный метод. Теперь он работает, например. с A = [8,6] и C = [2, 3] (см. ответ), хотя все остальные случаи я не проверял ...

benjaminplanche 17.08.2018 17:12

Приведенный выше код уже работает для пикселей N (A формы (N, 2)). Вы можете использовать функции arange / meshgrid для получения всех ваших образов A (форма (H, W, 2)), а затем сгладить тензор так, чтобы N = H * W.

benjaminplanche 18.08.2018 19:58

Да, я использовал arange и meshgrid для одного изображения как (H, W, 2). Однако есть еще один параметр - размер партии. Пиксели A должны быть (Batch, H, W, 2).

N.Z 18.08.2018 21:48

Предполагая, что все A одинаковы для каждого изображения в пакете, это будет означать просто разбиение / расширение вашего тензора, чтобы достичь (B, H, W, 2), перед сглаживанием до (B*H*W, 2) для вычисления точек. Точно так же, предполагая, что у вас уже есть C формы (B, 2), вы можете просто расширить его до (B, H, W, 2) перед сведением к той же форме, что и A.

benjaminplanche 19.08.2018 12:46

Я закончил. Спасибо!

N.Z 22.08.2018 10:12

Другие вопросы по теме