Зачем использовать сложную модель, когда простые трюки?
Вступление
Как мы знаем, машинное обучение повсеместно встречается в нашей повседневной жизни. От рекомендаций по продуктам на Amazon, таргетированной рекламы и предложений о том, что смотреть, до забавного Instagram фильтры.
Если с ними что-то пойдет не так, это, вероятно, не испортит вам жизнь. Может быть, у вас не получится это идеальное селфи, а может быть, компаниям придется больше тратить на рекламу.
Как насчет распознавания лиц в правоохранительных органах? Заявки на получение ссуды или ипотеки? Беспилотные автомобили?
В этих приложениях с высоким уровнем риска мы не можем действовать вслепую. Нам нужно уметь анализировать нашу модель, нам нужно будет уметь понимать и объяснять нашу модель, прежде чем она будет приближаться к производственной системе.
Объяснимое машинное обучение необходимо, когда мы принимаем решения о людях, которые могут негативно повлиять на их жизнь, например, ипотека или кредитная оценка.
Использование объяснимых моделей также обеспечивает более эффективную отладку, а также лучшее понимание справедливости, конфиденциальности, причинности и большего доверия к модели.
Оглавление
- Типы объяснимости
- Объясняемые модели
- Обобщенные линейные модели
- Деревья решений
- Обобщенные аддитивные модели
- Повышение монотонного градиента
- TabNet - "Сравнение"
- "Заключение"
Типы объяснимости
- Постфактум: это когда мы объясняем модель после того, как были сделаны прогнозы или после того, как модель была обучена. Это здорово, потому что эти методы позволяют нам объяснять очень сложные модели. Однако эти методы можно обмануть, они ошибочны при определенных условиях и требуют дополнительного уровня сложности для генерации объяснений. Распространенными примерами являются пакеты Python SHAP и LIME.
- Присущие: некоторые модели можно объяснить "из коробки", без дополнительных моделей или библиотек, размещенных поверх. Как правило, они проще и в некоторых случаях могут иметь меньшую предсказательную силу - хотя некоторые исследователи утверждают, что это не всегда так!
В этой статье будут рассмотрены модели, объяснимые по своей сути.
Естественно объяснимые модели
Lipton, 2016 (https://arxiv.org/abs/1606.03490) определяет объяснимость модели, используя 3 критерия:
- Симулируемость. Может ли человек пройти все этапы модели за «разумное» время?
- Разложимость. Можно ли разбить все аспекты модели, включая ее характеристики, параметры и веса?
- Алгоритмическая прозрачность: можем ли мы понять, как модель будет реагировать на невидимые данные.
Возможно, вам уже известны модели, соответствующие этому критерию; Деревья решений и Логистическая регрессия. Если в модели не используется слишком много функций, обе соответствуют всем трем критериям.
Даже если вы используете объяснимую модель, использование слишком большого количества функций или высокотехнологичных функций может снизить объяснимость вашей модели. Нам также необходимо, чтобы наши данные и предварительная обработка были объяснимыми.
Однако есть несколько менее известных объяснимых моделей, которые полезно иметь в своем наборе инструментов. Эти методы также могут иметь большую предсказательную силу, чем деревья решений и логистическая регрессия, при этом сохраняя некоторый уровень объяснимости, позволяя вам сбалансировать объяснимость с точность в вашем следующем проекте.
Обобщенные линейные модели (GLM)
Что это такое и как работает?
GLM - это на самом деле просто причудливый способ описания линейной или логистической регрессии. Ключевой концепцией здесь является наличие других линейных моделей, которые делают их более гибкими.
GLM состоит из 3 компонентов:
- Линейный предиктор. Это просто уравнение регрессии - линейная комбинация переменных и некоторой предикторной переменной.
2. Функция связи: связывает нашу линейную комбинацию переменных с распределением вероятностей. В линейной регрессии это просто функция ссылки идентичности.
3. Распределение вероятностей: так генерируется наша переменная y. В линейной регрессии это нормальное распределение.
Меняя их, мы можем получить разные модели. Использование функции связи Logit с распределением Бернулли дает нам логистическую регрессию.
Менее известной версией является регрессия Пуассона, в которой используется распределение Пуассона. Предполагается, что наша линейная комбинация переменных связана с логарифмом y.
Чтобы узнать больше о логистической регрессии, ознакомьтесь с моей статьей здесь:
Почему это объяснимо?
- Потому что это всего лишь простая математика
На графике ниже показано простое уравнение линейной регрессии только для одной переменной. Мы знаем, что можем вычислить y, нашу цель, для любого значения x, используя линию. Мы также знаем, как рассчитывается линия (минимизируя ошибки).
Мы можем масштабировать это до любого количества переменных, и математические расчеты по-прежнему остаются в силе. Мы просто складываем наши термины вместе, чтобы получить результат.
Логистическая регрессия или регрессия Пуассона более сложны, но основная концепция остается верной. Мы суммируем линейную комбинацию переменных.
- Коэффициенты что-то значат
В линейной регрессии наши коэффициенты даны в терминах целевой переменной. Это позволяет нам делать такие заявления, как:
Увеличение площади на 50 м² увеличит стоимость нашего дома на 1000 фунтов стерлингов.
В логистической регрессии это логарифмические шансы, которые мы можем преобразовать в вероятности.
Тот факт, что эти коэффициенты могут быть преобразованы в понятные человеку утверждения, действительно увеличивает объяснимость линейных моделей.
- Взаимодействия должны быть явно запрограммированы
Модели на основе деревьев будут определять взаимодействия между переменными, которые могут увеличивать сложность. Многие методы объяснимости плохо охватывают взаимодействия, и их бывает сложно разработать. Например
Увеличение площади на 50 м² увеличит стоимость нашего дома на 1000 фунтов стерлингов, пока мы не дойдем до 200 м², при этом цена нашего дома увеличится на 1000 фунтов стерлингов за 50 м² * количество спален.
Этот тип сложной взаимосвязи будет обнаружен в линейной модели только в том случае, если он вычислен и запрограммирован в явном во время проектирования функций.
Реализация
Деревья решений
Что это такое и как работает?
Большинству людей в какой-то момент своей жизни следовало увидеть Дерево решений! Алгоритмическая версия использует простую математику для создания «оптимального» дерева решений, то есть дерева, которое лучше всего разделяет наши данные.
Чтобы понять, как это работает, загляните сюда:
Почему это объяснимо?
Самое замечательное в деревьях решений то, что мы можем буквально просто извлечь все дерево и проследить, почему модель сделала прогноз для любой выборки в наборе данных.
Когда деревья становятся очень большими (max_depth = 7+), людям становится трудно следить за ними из-за экспоненциального роста количества листьев. Однако на этом этапе мы все еще можем написать базовый код, чтобы выделить путь, по которому наши данные достигли своего прогноза, а также проверить, как модель будет реагировать на невидимые данные.
Я считаю, что это одна из наиболее объяснимых моделей из-за очень небольшого количества математики и концепции, которая встречается в других странах мира.
Реализация
Как обычно, для этой модели я предпочел бы scikit-learn.
Для построения дерева существует множество различных вариантов, вот хороший список.
Обобщенные аддитивные модели (GAM)
Что это такое и как работает?
Обобщенные аддитивные модели (GAM) - это расширение GLM, которое снимает одно из основных ограничений; Теперь мы можем моделировать нелинейные отношения в наших данных.
GAM делают это с помощью серии сложных функций, известных как сплайны для оценки каждой переменной. Мы по-прежнему суммируем наши переменные, но сплайны означают, что переменные и целевое значение могут иметь нелинейную связь.
Чтобы узнать больше о GAM, загляните здесь:
Почему это объяснимо?
Несомненно, GAM менее объяснимы, чем логистическая или линейная регрессия. Модель намного сложнее, как и лежащая в ее основе математика. Однако они по-прежнему поддерживают уровень объяснимости, что является большим компромиссом, если учесть их гибкость.
- Комбинация нелинейных переменных?
Целевая переменная по-прежнему представляет собой просто сумму всех других переменных с некоторым весом, теперь у нас есть сложная функция, моделирующая каждую переменную. Мы по-прежнему можем извлекать и визуализировать эту функцию для каждой переменной, и большинство пакетов GAM используют графики частичной зависимости, чтобы сделать это для всех функций.
Взаимодействия необходимо программировать вручную, что снижает сложность. Мы также можем в общих чертах понять, как модель будет вести себя с невидимыми данными, поскольку мы знаем сплайн-функции для каждой функции.
Реализация
Из моих исследований кажется, что пакет mgcv в R лучше всего подходит для GAM. Однако я предпочитаю Python; два лучших варианта - Statsmodels и PyGAM.
Microsoft Research открыла исходный код своего пакета InterpretML, который включает в себя их Explainable Boosting Machine, которую они называют GAM 2.0, поскольку она использует GAM с условиями автоматического взаимодействия и градиентным усилением для обеспечения объяснимости, повышения производительности и уменьшения потребности специалистов по анализу данных в доступе к глубоко в модели.
Монотонное усиление градиента
Что это такое и как работает?
Модели повышения градиента считаются лучшими в своем классе для табличных данных, но из-за характера повышения они не поддаются интерпретации. Эти модели могут использовать сотни отдельных деревьев с разным весом. Они также склонны сами вырабатывать условия взаимодействия, в отношении которых у нас мало прозрачности. Обычно для повышения интерпретируемости этих моделей используют SHAP или LIME.
Монотонная связь - это когда цель и объект имеют линейную связь, например:
- Ваш ИМТ увеличивается, а риск сердечного приступа увеличивается.
- Ваш кредитный рейтинг уменьшается, а вероятность получения ссуды уменьшается.
- Количество дождя увеличивается, а количество арендованных велосипедов уменьшается.
Линейные модели полностью монотонны, но поскольку повышение градиента включает взаимодействия и может моделировать нелинейные отношения, они обычно не создают монотонных отношений.
XGBoost, LightGBM и Catboost имеют простой гиперпараметр, который заставляет переменную иметь положительную или отрицательную монотонную связь.
Почему это объяснимо?
Использование монотонных отношений означает, что мы можем использовать утверждения, подобные приведенным выше, для объяснения нашей модели. Это заставляет модель соответствовать критериям алгоритмической прозрачности, поскольку эта взаимосвязь фиксирована. Мы также можем встроить в модель некоторые из наших реальных знаний, которые сделают ее более понятной для деловых людей и внесут больше изменений в производство.
Реализация
- XGBoost
В XGBoost мы указываем параметр monotone_constraints как строковый кортеж (скобки внутри речевых знаков) с одним числом для каждой функции в нашем наборе данных, поэтому «(1,0, -1)» обозначает функции 1, 2 и 3 . 1 - положительная монотонная связь, -1 - отрицательная, а 0 - никакая связь.
import xgboost as xgb params = {'monotone_constraints':'(1,0,-1)'} model = xgb.train(params, X_train, num_boost_round = 1000, early_stopping_rounds = 10)
- LightGBM
LightGBM в целом аналогичен XGBoost, но мы должны передавать наши функции в виде списка, а не строки / кортежа. LightGBM также предлагает дополнительный параметр метода. Используя это, мы можем выбрать, насколько сильно модель будет пытаться придерживаться ограничения.
basic
, самый простой метод ограничения монотонности. Это совсем не замедляет работу библиотеки, но чрезмерно ограничивает прогнозы.
intermediate
, более продвинутый метод, который может очень немного замедлить работу библиотеки. Однако этот метод гораздо менее ограничивает, чем базовый метод, и должен значительно улучшить результаты.
advanced
, еще более продвинутый метод, который может замедлить работу библиотеки. Однако этот метод еще менее ограничивает, чем промежуточный метод, и должен снова значительно улучшить результаты.
import lightgbm as lgb
params = {'monotone_constraints': [-1, 0, 1],
'monotone_constraints_method
':'basic'}
model = lgb.train(params,
X_train,
num_round = 1000,
early_stopping_rounds = 10)
- Catboost
Catboost очень похож на другие, но предлагает большую гибкость, поскольку мы можем передавать ограничения в виде массива, использовать нарезку и явно указывать функцию.
Параметр называется monotone_constraints
, и вы можете ознакомиться с документацией по Catboost здесь.
TabNet
Что это такое и как работает?
TabNet был опубликован исследователями Google Brain в 2019 году. Традиционно, подходы нейронных сетей существенно не улучшали повышение градиента при работе с табличными данными. Однако Tabnet смог превзойти ведущие древовидные модели по множеству тестов. Это значительно более объяснимо, чем модели усиленного дерева, поскольку у него есть встроенная объяснимость. Его также можно использовать без предварительной обработки функций.
Моя статья о TabNet описывает эту модель более подробно. Посмотрите здесь:
Почему это объяснимо?
TabNet использует механизм последовательного внимания для выбора наиболее важных функций, это влияет на «маску», которая скрывает наименее важные функции. Мы можем использовать веса этой маски, чтобы понять, какие функции используются чаще, чем другие, что, по сути, позволяет нам понять, какие функции модель использует для своих прогнозов.
Выбор функций выполняется на уровне строки набора данных, что означает, что мы можем фактически исследовать, какие функции были выбраны для одного прогноза. Количество масок - гиперпараметр модели.
Реализация
Лучше всего использовать TabNet с реализацией PyTorch от Dreamquark. Он использует оболочку стиля scikit-learn и совместим с графическим процессором. Dreamquark также предоставляет несколько действительно отличных ноутбуков, которые прекрасно демонстрируют, как реализовать TabNet, а также работают над проверкой заявлений авторов о точности моделей на определенных тестах.
Классификация
Регрессия
Сравнение моделей
Вернемся к трем критериям Липтона и применим их к каждой модели. Напоминаем, что критерии…
- Симулируемость. Может ли человек пройти все этапы модели за «разумное» время?
- Разложимость. Можно ли разбить все аспекты модели, включая ее характеристики, параметры и веса?
- Алгоритмическая прозрачность: можем ли мы понять, как модель будет реагировать на невидимые данные.
Мы также хотим рассмотреть локальную объяснимость; в какой степени модель может сделать единичный прогноз с точки зрения того, какие функции она использовала и в какой степени она использовала каждую функцию для принятия своего решения.
Я суммировал это в таблице ниже, присвоив каждой модели низкий, средний или высокий балл по каждому критерию. Это не точная наука, но вы можете рассматривать каждый результат относительно линейной регрессии.
Помните, что любая разработка функций может полностью сбросить эти оценки. Создание сложных условий взаимодействия, математических преобразований или функций, полученных из нейронной сети, может повысить точность, но, безусловно, снизит объяснимость. Дополнительные функции также могут снизить объяснимость.
Выводы
Ажиотаж и дискуссии вокруг объяснимости моделей только нарастают. Поскольку ИИ используется в языковом моделировании, распознавании лиц и беспилотных автомобилях, наличие моделей, которые могут обосновать свое решение, как никогда важно.
Попробуйте одну из этих объяснимых моделей в своем следующем проекте и дайте мне знать, как это у вас получится.