Умный способ сериализации / десериализации классов в / из графа Tensorflow
Автоматически привязывать свои поля к графику тензорного потока и обратно
Было бы здорово автоматически привязать поля класса к переменным тензорного потока в графе и восстанавливать их, не возвращая вручную каждую переменную?
Код для этой статьи можно найти здесь, версию для ноутбука jupyter можно найти здесь
Изображение у вас Model
класс
Обычно вы сначала строите свою модель, а затем тренируете ее. После этого вы хотите получить из сохраненного графика старые переменные, не перестраивая всю модель с нуля.
<tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>
Теперь представьте, что мы только что обучили нашу модель и хотим ее сохранить. Обычный образец
Теперь вы хотите выполнить вывод, то есть вернуть свои данные, загрузив сохраненный график. В нашем случае мы хотим, чтобы переменная с именем variable
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
Теперь мы можем вернуть наш variable
с графика
name: "variable" op: "VariableV2" attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } }
Но что, если мы снова захотим использовать наш класс model
? Если мы попробуем сейчас позвонить model.variable
, мы получим None
None
Одно из решений - построить заново всю модель и после этого восстановить график.
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt <tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>
Вы уже понимаете, что это пустая трата времени. Мы можем привязать model.variable
непосредственно к правильному узлу графа,
name: "variable" op: "VariableV2" attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } }
Теперь у нас есть очень большая модель с вложенными переменными. Чтобы правильно восстановить указатель каждой переменной в модели, вам необходимо:
- назовите каждую переменную
- вернуть переменные из графика
Было бы здорово, если бы мы могли автоматически получать все переменные, заданные как поля в классе Model?
TFGraphConvertible
Я создал класс под названием TFGraphConvertible
. Вы можете использовать TFGraphConvertible
для автоматической сериализации и десериализации 'класса.
Давайте воссоздадим нашу модель
Он предоставляет два метода: to_graph
и from_graph
.
Сериализовать - to_graph
Чтобы сериализовать класс, вы можете вызвать метод to_graph, который создает словарь имен полей - ›имя переменных тензорного потока. Вам нужно передать fields
аргументов, словарь того поля, которое мы хотим сериализовать. В нашем случае мы можем просто передать их все.
{'variable': 'variable_2:0'}
Он создаст словарь со всеми полями в качестве ключей и соответствующими именами переменных тензорного потока в качестве значений.
Десериализовать - from_graph
Чтобы десериализовать класс, вы можете вызвать метод from_graph, который берет предыдущий созданный словарь и связывает каждое поле класса с правильными переменными тензорного потока.
None <tf.Tensor 'variable_2:0' shape=(1,) dtype=int32_ref>
И теперь ваш model
вернулся!
Полный пример
Посмотрим более интересный пример! Мы собираемся обучить / восстановить модель для набора данных MNIST
Получим набор данных!
Using TensorFlow backend.
Пришло время потренировать это
0.125
0.46875
0.8125
0.953125
0.828125
0.890625
0.796875
0.9375
0.953125
0.921875
Идеально! Сохраним сериализованную модель в памяти
{'x': 'ExpandDims:0',
'y': 'one_hot:0',
'forward_raw': 'dense_1/BiasAdd:0',
'accuracy': 'Mean:0',
'loss': 'Mean_1:0',
'train_step': 'Adam'}
Затем мы сбрасываем график и воссоздаем модель.
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
Конечно, наших переменных в mnist_model
не существует
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-21-9def5e0d8f6c> in <module>() ----> 1 mnist_model.accuracy AttributeError: 'MNISTModel' object has no attribute 'accuracy'
Давайте воссоздадим их, вызвав метод from_graph
.
<tf.Tensor 'Mean:0' shape=() dtype=float32>
Теперь mnist_model
готов к работе, давайте посмотрим точность на тестовой выборке.
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
1.0
Вывод
В этом руководстве мы увидели, как сериализовать класс и связать каждое поле с правильным тензором в графе тензорного потока. Имейте в виду, что вы можете сохранить serialized_model
в формате .json
и загружать его прямо из любого места. Таким образом, вы можете напрямую создать свою модель с помощью объектно-ориентированного программирования и извлекать все переменные внутри них без необходимости их перестраивать.
Спасибо за чтение
Франческо Саверио Цуппичини
Первоначально опубликовано на gist.github.com.