Я хочу сохранить лучшую модель, а затем загрузить ее во время теста. Поэтому я использовал следующий метод:
def train():
#training steps …
if acc > best_acc:
best_state = model.state_dict()
best_acc = acc
return best_state
Затем в основной функции я использовал:
model.load_state_dict(best_state)
возобновить модель.
Однако я обнаружил, что best_state всегда совпадает с последним состоянием во время обучения, а не с лучшим состоянием. Кто-нибудь знает причину и как этого избежать?
Кстати, я знаю, что могу использовать torch.save(the_model.state_dict(), PATH)
, а затем загрузить модель с помощью the_model.load_state_dict(torch.load(PATH))
. Однако я не хочу сохранять параметры в файл, так как функции обучения и тестирования находятся в одном файле.