Умный способ сериализации / десериализации классов в / из графа Tensorflow

Автоматически привязывать свои поля к графику тензорного потока и обратно

Было бы здорово автоматически привязать поля класса к переменным тензорного потока в графе и восстанавливать их, не возвращая вручную каждую переменную?

Код для этой статьи можно найти здесь, версию для ноутбука jupyter можно найти здесь

Изображение у вас Model класс

Обычно вы сначала строите свою модель, а затем тренируете ее. После этого вы хотите получить из сохраненного графика старые переменные, не перестраивая всю модель с нуля.

<tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>

Теперь представьте, что мы только что обучили нашу модель и хотим ее сохранить. Обычный образец

Теперь вы хотите выполнить вывод, то есть вернуть свои данные, загрузив сохраненный график. В нашем случае мы хотим, чтобы переменная с именем variable

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt

Теперь мы можем вернуть наш variable с графика

name: "variable" op: "VariableV2" attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } }

Но что, если мы снова захотим использовать наш класс model? Если мы попробуем сейчас позвонить model.variable, мы получим None

None

Одно из решений - построить заново всю модель и после этого восстановить график.

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt <tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>

Вы уже понимаете, что это пустая трата времени. Мы можем привязать model.variable непосредственно к правильному узлу графа,

name: "variable" op: "VariableV2" attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } }

Теперь у нас есть очень большая модель с вложенными переменными. Чтобы правильно восстановить указатель каждой переменной в модели, вам необходимо:

  • назовите каждую переменную
  • вернуть переменные из графика

Было бы здорово, если бы мы могли автоматически получать все переменные, заданные как поля в классе Model?

TFGraphConvertible

Я создал класс под названием TFGraphConvertible. Вы можете использовать TFGraphConvertible для автоматической сериализации и десериализации 'класса.

Давайте воссоздадим нашу модель

Он предоставляет два метода: to_graph и from_graph.

Сериализовать - to_graph

Чтобы сериализовать класс, вы можете вызвать метод to_graph, который создает словарь имен полей - ›имя переменных тензорного потока. Вам нужно передать fields аргументов, словарь того поля, которое мы хотим сериализовать. В нашем случае мы можем просто передать их все.

{'variable': 'variable_2:0'}

Он создаст словарь со всеми полями в качестве ключей и соответствующими именами переменных тензорного потока в качестве значений.

Десериализовать - from_graph

Чтобы десериализовать класс, вы можете вызвать метод from_graph, который берет предыдущий созданный словарь и связывает каждое поле класса с правильными переменными тензорного потока.

None <tf.Tensor 'variable_2:0' shape=(1,) dtype=int32_ref>

И теперь ваш model вернулся!

Полный пример

Посмотрим более интересный пример! Мы собираемся обучить / восстановить модель для набора данных MNIST

Получим набор данных!

Using TensorFlow backend.

Пришло время потренировать это

0.125
0.46875
0.8125
0.953125
0.828125
0.890625
0.796875
0.9375
0.953125
0.921875

Идеально! Сохраним сериализованную модель в памяти

{'x': 'ExpandDims:0', 
'y': 'one_hot:0', 
'forward_raw': 'dense_1/BiasAdd:0', 
'accuracy': 'Mean:0', 
'loss': 'Mean_1:0', 
'train_step': 'Adam'}

Затем мы сбрасываем график и воссоздаем модель.

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt

Конечно, наших переменных в mnist_model не существует

--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-21-9def5e0d8f6c> in <module>() ----> 1 mnist_model.accuracy AttributeError: 'MNISTModel' object has no attribute 'accuracy'

Давайте воссоздадим их, вызвав метод from_graph.

<tf.Tensor 'Mean:0' shape=() dtype=float32>

Теперь mnist_model готов к работе, давайте посмотрим точность на тестовой выборке.

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt 
1.0

Вывод

В этом руководстве мы увидели, как сериализовать класс и связать каждое поле с правильным тензором в графе тензорного потока. Имейте в виду, что вы можете сохранить serialized_model в формате .json и загружать его прямо из любого места. Таким образом, вы можете напрямую создать свою модель с помощью объектно-ориентированного программирования и извлекать все переменные внутри них без необходимости их перестраивать.

Спасибо за чтение

Франческо Саверио Цуппичини

Первоначально опубликовано на gist.github.com.