Как нормализовать данные с помощью слоя Lambda?

У меня есть список x как

x = list(np.arange(10))

min = np.min(x)

max = np.max(x)

и я могу создать оконный набор данных, используя следующий метод:

def get_windowed_data(series,window_size):

  dt = tf.data.Dataset.from_tensor_slices(series)
  dt = dt.window(window_size, shift = 1,drop_remainder = True)
  dt = dt.flat_map(lambda window: window.batch(window_size)) # make each window a batch
  dt = dt.map(lambda window: (window[:-1],window[-1:])) # consider the last element as label and the rest as window
  return dt

который дает мне результат вывода. Итак, каждая строка содержит кортеж, первый элемент которого представляет собой список с несколькими элементами, а второй элемент представляет собой список с одним элементом.

[0 1 2 3]   [4]
[1 2 3 4]   [5]
[2 3 4 5]   [6]
[3 4 5 6]   [7]
[4 5 6 7]   [8]
[5 6 7 8]   [9]

Теперь я хочу нормализовать (между 0 и 1) только данные в первом элементе и сохранить метки, как и раньше, и попробовал следующий код:

def get_windowed_data(series,window_size,min,max):

  dt = tf.data.Dataset.from_tensor_slices(series)
  dt = dt.window(window_size, shift = 1,drop_remainder = True)
  #dt = dt.flat_map(lambda window: window.batch(window_size)) # make each window a batch
  dt = dt.flat_map(lambda window: ([ (x-min)/max for x in window[:-1].numpy()],window[-1:])) 
  return dt
  

Так, например, вывод первых двух строк должен быть:

[0.0, 0.1111111111111111, 0.2222222222222222, 0.3333333333333333] [4]
[0.1111111111111111, 0.2222222222222222, 0.3333333333333333, 0.4444444444444444]      [5]

Однако, используя мой код, он жалуется:

   lambda window: ([ (x-min)/max for x in window[:-1].numpy()],window[-1:]))

    TypeError: '_VariantDataset' object is not subscriptable
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
36
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

После разделения на два элемента вы можете использовать другую функцию map:

ds = ds.map(lambda wx, wy: ((wx - min) / max, wy))

wx это окно, wy здесь цель. Таким образом, полный пример выглядит следующим образом:

import tensorflow as tf
import numpy as np

x = list(np.arange(10))

min = np.min(x)
max = np.max(x)


def get_windowed_data(series, window_size, min_value, max_value):
    ds = tf.data.Dataset.from_tensor_slices(series)
    ds = ds.window(window_size, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda w: w.batch(window_size))
    ds = ds.map(lambda w: (w[:-1], w[-1:]))
    ds = ds.map(lambda wx, wy: ((wx - min_value) / max_value, wy))
    return ds


data_normalized = get_windowed_data(x, 5, min, max)

for x, y in data_normalized:
    print(x.numpy(), y.numpy())

Это напечатает:

[0.         0.11111111 0.22222222 0.33333333] [4]
[0.11111111 0.22222222 0.33333333 0.44444444] [5]
[0.22222222 0.33333333 0.44444444 0.55555556] [6]
[0.33333333 0.44444444 0.55555556 0.66666667] [7]
[0.44444444 0.55555556 0.66666667 0.77777778] [8]
[0.55555556 0.66666667 0.77777778 0.88888889] [9]

Другие вопросы по теме