WedX - журнал о программировании и компьютерных науках

best_state меняется вместе с моделью во время обучения в pytorch

Я хочу сохранить лучшую модель, а затем загрузить ее во время теста. Поэтому я использовал следующий метод:

def train():  
    #training steps …  
    if acc > best_acc:  
        best_state = model.state_dict()  
        best_acc = acc
    return best_state  

Затем в основной функции я использовал:

model.load_state_dict(best_state)  

возобновить модель.

Однако я обнаружил, что best_state всегда совпадает с последним состоянием во время обучения, а не с лучшим состоянием. Кто-нибудь знает причину и как этого избежать?

Кстати, я знаю, что могу использовать torch.save(the_model.state_dict(), PATH), а затем загрузить модель с помощью the_model.load_state_dict(torch.load(PATH)). Однако я не хочу сохранять параметры в файл, так как функции обучения и тестирования находятся в одном файле.


  • версия 1.1.0, линукс, графический процессор 10.06.2019

Ответы:


1

model.state_dict() is OrderedDict

from collections import OrderedDict

Ты можешь использовать:

from copy import deepcopy

Решить проблему

Вместо:

best_state = model.state_dict() 

Вы должны использовать:

best_state = copy.deepcopy(model.state_dict())

Глубокая (не поверхностная) копия заставляет изменяемый экземпляр OrderedDict не изменять best_state по ходу дела.

Вы можете проверить мой другой ответ о сохранении диктора состояния в PyTorch.

12.06.2019

2

Когда вы сохраняете состояние модели, вы должны сохранить в сети следующие вещи:

1) Состояние оптимизатора и 2) Состояние модели

Вы можете определить один метод в своей модели класса следующим образом

def save_state(state,filename):
    torch.save(state,filename)

''' Когда вы сохраняете состояние, сделайте следующее: '''

Model model //for example  
model.save_state({'state_dict':model.state_dict(), 'optimizer': optimizer.state_dict()}) 

Сохраненная модель будет храниться как model.pth.tar (для примера)

Теперь во время загрузки выполните следующие шаги,

checkpoint = torch.load('model.pth.tar')         

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

Надеюсь, что это поможет вам.

12.06.2019
  • Этот ответ хорошо подходит для этот вопрос, но не идеально подходит для этого. 12.06.2019
  • Спасибо, я знаю, что могу сохранить состояние в файл. Но я предпочитаю ответ Саурава 13.06.2019
  • @ Lei_Bai Очень рад помочь вам немного. Надеюсь, я буду вам полезен и в будущем. 13.06.2019
  • Новые материалы

    Как создать диаграмму градиентной кисти с помощью D3.js
    Резюме: Из этого туториала Вы узнаете, как добавить градиентную кисть к диаграмме с областями в D3.js. Мы добавим градиент к значениям SVG и применим градиент в качестве заливки к диаграмме с..

    Я хотел выучить язык программирования MVC4, но не мог выучить его раньше, потому что это выглядит сложно…
    Просто начните и учитесь самостоятельно Я хотел выучить язык программирования MVC4, но не мог выучить его раньше, потому что он кажется мне сложным, и я бросил его. Это в основном инструмент..

    Лицензии с открытым исходным кодом: руководство для разработчиков и создателей
    В динамичном мире разработки программного обеспечения открытый исходный код стал мощной парадигмой, способствующей сотрудничеству, инновациям и прогрессу, движимому сообществом. В основе..

    Объяснение документов 02: BERT
    BERT представил двухступенчатую структуру обучения: предварительное обучение и тонкая настройка. Во время предварительного обучения модель обучается на неразмеченных данных с помощью..

    Как проанализировать работу вашего классификатора?
    Не всегда просто знать, какие показатели использовать С развитием глубокого обучения все больше и больше людей учатся обучать свой первый классификатор. Но как только вы закончите..

    Работа с цепями Маркова, часть 4 (Машинное обучение)
    Нелинейные цепи Маркова с агрегатором и их приложения (arXiv) Автор : Бар Лайт Аннотация: Изучаются свойства подкласса случайных процессов, называемых дискретными нелинейными цепями Маркова..

    Crazy Laravel Livewire упростил мне создание электронной коммерции (панель администратора и API) [Часть 3]
    Как вы сегодня, ребята? В этой части мы создадим CRUD для данных о продукте. Думаю, в этой части я не буду слишком много делиться теорией, но чаще буду делиться своим кодом. Потому что..


    Для любых предложений по сайту: [email protected]