Авторы Zach Witzel и Xiluo He в рамках финального проекта Stanford CS 224W.
Описание задачи и мотивация:
Авиационная отрасль является важнейшим компонентом современного общества, объединяющим людей и предприятия по всему миру. Однако инфраструктура авиакомпаний — хрупкая система, чреватая отменами и задержками. Отмена рейсов приводит к существенным неудобствам для пассажиров и значительным затратам для авиакомпаний. В то же время имеется значительный объем данных об истории полетов в аэропортах по всей стране. Прогнозирование отмены рейсов в конкретном аэропорту может быть полезным как для авиакомпаний, так и для пассажиров, позволяя им принимать обоснованные решения и минимизировать сбои. Из-за естественной графообразной конфигурации аэропортов как вершин и рейсов, действующих как ребра между ними, это делает проблему подходящей для решения с помощью графовых нейронных сетей. Цель этого проекта — создать модель GNN, которая использует данные об отмене рейсов для аэропортов в США и связанные с ними функции и может прогнозировать ожидаемый процент отмен рейсов в данном аэропорту и году. Оттуда мы демонстрируем полезность такой информации, создавая программу, которая находит маршрут между двумя разными аэропортами с минимальной вероятностью отмены.
Вот ссылки на наш проект, коллаборации и заархивированные данные для более удобного изучения материала:
- Папка проекта и данные: https://drive.google.com/drive/folders/1SLo0WaF4NC7ioooAljRDUlthAjJkMiX5?usp=sharing
Набор данных:
Источники данных и функции
Наш граф состоит из аэропортов в качестве узлов и рейсов между аэропортами в качестве ребер. Мы используем данные узлов и краев из 2 разных источников данных, все из которых происходят из OpenFlights, проекта с открытым исходным кодом, который содержит информацию об аэропортах и коммерческих рейсах за разные промежутки времени [1].
Данные о расположении, работе, отменах и отклонениях аэропортов за 2004–2014 годы:
- https://github.com/pdsmith1223/datasets/blob/master/Airport_operations.csv
- https://github.com/pdsmith1223/datasets/blob/master/airport_cancellations.csv
- https://github.com/pdsmith1223/datasets/blob/master/airports.csv
Внутренние рейсы в США с 1990 по 2009 год:
Для наших функций узла мы используем текущую информацию об аэропорте, а также прошлую информацию. Текущая информация содержит количество отклонений от аэропорта, информацию о выходе на посадку, местоположение аэропорта, информацию о рейсах такси и задержках в аэропорту. Прошлые данные содержат информацию об этикетках за предыдущие 5 лет, поскольку эта информация будет доступна в течение интересующего года.
Для наших функций Edge/Flight мы включили информацию о количестве пассажиров в год, количестве рейсов, выполненных за год, количестве мест, доступных в год, и средней дальности полета.
Мы нормализуем функции узлов и ребер, а также метки узлов, используя наш обучающий набор.
Предварительная обработка данных
Мы взяли самые последние данные из наших существующих наборов данных и использовали данные аэропорта за 2014 год с данными о полетах за 2009 год. Хотя в США 383 аэропорта, у нас были данные по 61 аэропорту (29 крупных аэропортов, 26 аэропортов среднего размера и 6 аэропортов малого размера). Было доступно 123 443 рейса. Однако это приведет к чрезмерному сглаживанию, поскольку практически каждый узел будет находиться на расстоянии не более 2 переходов от любого другого узла в сети. Чтобы исправить это, мы выбрали самый загруженный рейс из каждого аэропорта вылета и рассмотрели только рейсы с пассажиропотоком более 200 000 пассажиров в год. Мы преобразуем каждое ребро в ненаправленное ребро и объединяем повторяющиеся ребра. Благодаря нашей предварительной обработке мы получаем неориентированный полносвязный граф. Наш окончательный граф имеет 61 узел, 27 признаков узла, 746 ребер, 4 признака края и среднюю степень 12,23.
В целях визуализации мы увеличили годовой порог пассажиропотока до 700 000, чтобы уменьшить количество присутствующих ребер. Цвет узлов представляет собой метку с более темными цветами, представляющими больший процент отмененных рейсов. Вот изображение нашего графика, на котором мы использовали широту и долготу каждого аэропорта, чтобы наложить их поверх карты США в проекции Меркатора:
Сплиты
Наш набор данных содержит только часть данных, которые ранее были публично извлечены. Это означает, что, индуктивно разделив наш график, мы можем обобщить его на аэропорты США, которых нет в нашем наборе данных, а также на международные аэропорты.
Мы используем разделение 70/10/20% между нашими обучающими, проверочными и тестовыми наборами. Мы обнаружили, что, поскольку у нас есть небольшое количество узлов, случайное разделение узлов вызывает плохое поведение при обучении, а разделение оказывает большое влияние на производительность. Мы выдвинули гипотезу и протестировали равномерное разделение набора данных по размеру аэропорта (например, обучающий набор содержит случайные 70% крупных аэропортов, случайные 70% средних аэропортов, случайные 70% малых аэропортов). Тренировочное поведение с этим разделением было гораздо более репрезентативным для возможностей модели.
Вот наша реализация для равномерного разделения по размеру аэропорта:
# Split evenly wrt airport size # Get indices of large, medium and small airports larges = [] mediums = [] smalls = [] smallandnonhub_airports = smallhub_airports + nonhub_airports for airport, airport_id in airport_to_id.items(): if airport in largehub_airports: larges.append(1) else: larges.append(0) if airport in mediumhub_airports: mediums.append(1) else: mediums.append(0) if airport in smallandnonhub_airports: smalls.append(1) else: smalls.append(0) larges = np.array(larges) mediums = np.array(mediums) smalls = np.array(smalls) larges_inds = np.argpartition(larges, -larges.sum())[-larges.sum():] np.random.shuffle(larges_inds) mediums_inds = np.argpartition(mediums, -mediums.sum())[-mediums.sum():] np.random.shuffle(mediums_inds) smalls_inds = np.argpartition(smalls, -smalls.sum())[-smalls.sum():] np.random.shuffle(smalls_inds) # Split data train_p = 0.7 val_p = 0.1 test_p = 0.2 larges_train_inds = larges_inds[:int(train_p*len(larges_inds))] larges_val_inds = larges_inds[int(train_p*len(larges_inds)):int(train_p*len(larges_inds))+int(val_p*len(larges_inds))+1] larges_test_inds = larges_inds[int(train_p*len(larges_inds))+int(val_p*len(larges_inds))+1:] mediums_train_inds = mediums_inds[:int(train_p*len(mediums_inds))] mediums_val_inds = mediums_inds[int(train_p*len(mediums_inds)):int(train_p*len(mediums_inds))+int(val_p*len(mediums_inds))+1] mediums_test_inds = mediums_inds[int(train_p*len(mediums_inds))+int(val_p*len(mediums_inds))+1:] smalls_train_inds = smalls_inds[:int(train_p*len(smalls_inds))] smalls_val_inds = smalls_inds[int(train_p*len(smalls_inds)):int(train_p*len(smalls_inds))+int(val_p*len(smalls_inds))+1] smalls_test_inds = smalls_inds[int(train_p*len(smalls_inds))+int(val_p*len(smalls_inds))+1:] train_inds = np.concatenate((larges_train_inds, mediums_train_inds, smalls_train_inds)) val_inds = np.concatenate((larges_val_inds, mediums_val_inds, smalls_val_inds)) test_inds = np.concatenate((larges_test_inds, mediums_test_inds, smalls_test_inds)) # Create split masks num_nodes = len(train_inds) + len(val_inds) + len(test_inds) train_mask = torch.zeros(num_nodes, dtype=torch.bool) val_mask = torch.zeros(num_nodes, dtype=torch.bool) test_mask = torch.zeros(num_nodes, dtype=torch.bool) for idx in train_inds: train_mask[idx] = True for idx in val_inds: val_mask[idx] = True for idx in test_inds: test_mask[idx] = True data_split = data data_split.train_mask = train_mask data_split.val_mask = val_mask data_split.test_mask = test_mask
Мы можем увидеть нашу схему разделения на следующих графиках. Наши графики окрашены в соответствии с размером аэропорта (большой, средний, маленький), причем самый темный цвет соответствует крупным аэропортам:
Объяснение модели:
Обзор GNN
Графовые нейронные сети (GNN) — это класс моделей машинного обучения, которые были разработаны специально для обработки данных, структурированных в виде графов. Благодаря многочисленным приложениям, от биологии до финансов, GNN становятся все более популярными в последние годы. Недавнее развитие исследований в этой области привело к тому, что GNN стали более надежными и устойчивыми к различным структурам графов, превзойдя предыдущие современные модели во многих тестах. Текущие модели используют различные функции передачи сообщений и агрегации для захвата как локальной, так и глобальной структуры графа, что позволяет GNN хорошо выполнять задачи прогнозирования на уровне узлов, ребер и графов. Простые примеры GNN включают функцию передачи сообщений MLP и функцию агрегирования среднего пула.
Для каждого узла в нашем графе мы получаем дерево вычислений, состоящее из соседей узла и соседей соседа. Это дерево вычислений может стать сколь угодно большим, но глубокое дерево вычислений может привести к чрезмерному сглаживанию, явлению, когда деревья вычислений многих узлов выглядят одинаково и, таким образом, вычисляют одинаковое вложение для этих узлов.
Pytorch-Geometric (PyG) — популярный фреймворк Python, построенный на основе Pytorch для работы с графовыми нейронными сетями [2]. Это позволяет нам полностью определять граф, создавать модели GNN, предоставлять существующие реализации для известных моделей GNN и поддерживать общие преобразования графов для сквозного обучения GNN.
Модель GraphSAGE
GraphSAGE [3] — это конкретная модель GNN, которая хорошо подходит для этой задачи, поскольку она может эффективно изучать вложения для узлов в графе (в данном случае аэропортов) путем агрегирования информации из их локального окружения (близлежащие аэропорты и другие соответствующие функции). ).
Модель GraphSAGE подходит для этой задачи по нескольким причинам:
- Пространственные отношения. На эффективность аэропортов и отмен рейсов могут влиять пространственные факторы, такие как задержки на земле или дата рейса. GraphSAGE может фиксировать эти отношения, собирая информацию из близлежащих аэропортов, что позволяет модели изучать пространственные закономерности, которые могут иметь отношение к отменам рейсов.
- Масштабируемость: GraphSAGE — это алгоритм индуктивного обучения, который может обобщаться на невидимые узлы, что означает, что он может обрабатывать графы различных размеров и структур. Это особенно полезно для задачи прогнозирования отмены рейсов, поскольку со временем могут добавляться новые аэропорты, а структура графа может изменяться.
- Изучение функций: GraphSAGE может изучать представления узлов, которые включают как локальную, так и глобальную информацию, что имеет решающее значение для этой задачи. Локальная информация, такая как отдельные атрибуты аэропорта, а также глобальная информация, такая как взаимосвязь между аэропортами, могут влиять на количество отмен, и GraphSAGE может научиться уравновешивать эти факторы при прогнозировании.
Алгоритм GraphSAGE зацикливается на K итераций. На каждой итерации он проходит через каждый узел, применяет сообщение и функцию агрегирования к узлу на основе его соседей. Затем он нормализует векторы. Единственное различие между нашей реализацией и приведенным выше псевдокодом заключается в том, что мы используем функцию агрегатора средних значений в нашей модели GraphSAGE, заменяя строки 4 и 5 в приведенном выше псевдокоде. Сначала он объединяет вектор целевого узла с каждым из узлов в его окрестности. Затем он берет поэлементное среднее значение каждого вектора, который он объединяет, умножает их на весовую матрицу W и применяет к результату сигмовидную функцию.
Используя модель GraphSAGE для прогнозирования ожидаемого процента отмены рейсов для данного аэропорта, мы можем использовать возможности графовых нейронных сетей для анализа сложных взаимосвязей между аэропортами и окружающей их средой. Это может привести к более точным прогнозам и помочь заинтересованным сторонам в авиационной отрасли лучше прогнозировать и смягчать последствия отмены рейсов.
Мы демонстрируем нашу реализацию агрегации GraphSAGE:
# Self Created GraphSAGE CONV Model # class MyGraphSageCONV(MessagePassing): def __init__(self, in_channels, out_channels, normalize = True, bias = False, **kwargs): super(MyGraphSageCONV, self).__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.normalize = normalize # Layers for message and update functions self.lin_l = nn.Linear(in_features=in_channels, out_features=out_channels, bias=True) self.lin_r = nn.Linear(in_features=in_channels, out_features=out_channels, bias=True) self.reset_parameters() def reset_parameters(self): self.lin_l.reset_parameters() self.lin_r.reset_parameters() def forward(self, x, edge_index, size = None): #Message passing and post-processing out = self.lin_l(x) x_prop = self.propagate(edge_index=edge_index, x=(x, x), size=size) x_prop = self.lin_r(x_prop) out = out + x_prop if self.normalize: out = F.normalize(out, p=2) return out def message(self, x_j): #Message Function out = x_j return out def aggregate(self, inputs, index, dim_size = None): #Aggregate function node_dim = self.node_dim out = torch_scatter.scatter(src=inputs, index=index, dim=node_dim, dim_size=dim_size, reduce='mean') return out
Используя либо описанную выше реализацию GraphSAGE, либо класс SAGEConv от Pytorch-Geometric (с агрегатором среднего), мы можем построить GNN, наложив друг на друга два слоя SAGEConv, за которыми следует двухслойный MLP, чтобы углубить модель без чрезмерного сглаживания. . Наш GraphSAGE GNN также использует выборку соседей [3] (используя схему выборки (15, 5) для двухслойного MLP) для дальнейшего предотвращения чрезмерного сглаживания и регуляризации (нормы L2 и отсева [4]) для предотвращения переобучения.
Вот код для нашей реализации GraphSAGE:
class GraphSAGENeighborSampling(torch.nn.Module): """GraphSAGE Model with Neighbor Sampling.""" def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1, weight_decay=5e-3, lr=1e-3): super().__init__() self.dropout = dropout #GraphSAGE layers self.conv1 = MyGraphSageCONV(in_dim, hidden_dim) self.conv2 = MyGraphSageCONV(hidden_dim, hidden_dim) # self.conv3 = SAGEConv(hidden_dim, hidden_dim) #uncomment for 3 layer GraphSAGE self.convs = [self.conv1, self.conv2] # add self.conv3 for 3 layers GraphSAGE #Post-processing MLP layer self.post_mp = torch.nn.Sequential( torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, out_dim)) self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay) def forward(self, x, adjs, train=True, return_embedding=False): # Use Neighbor Sampling for Training if train: for i, (edge_index, _, size) in enumerate(adjs): x_target = x[:size[1]] x = self.convs[i]((x, x_target), edge_index) x = F.elu(x) x = F.dropout(x, p=self.dropout) # Evaluate on Entire Graph else: x = self.conv1(x, adjs) x = F.elu(x) x = F.dropout(x, p=self.dropout) x = self.conv2(x, adjs) x = F.elu(x) x = F.dropout(x, p=self.dropout) #Save GraphSAGE embedding embedding = x x = self.post_mp(x) if return_embedding: return x, F.log_softmax(x, dim=-1), embedding else: return x, F.log_softmax(x, dim=-1)
Мы начинаем с 31 функции для каждого из 61 узла, используя полный пакет. Здесь мы можем увидеть визуализацию полученных размеров после каждого шага:
Почему бы не использовать GCN или GAT?
Наша задача связана с индуктивностью (как объяснено в разделе набора данных), и сверточная сеть графа (GCN) [5] не будет работать хорошо, потому что она требует просмотра каждого узла во время обучения. Мы не рассматривали использование сети внимания графа (GAT) [6], потому что мы хотим явно указать модели, какие узлы являются более важными, с помощью краевых функций, а не заставлять модель неявно изучать веса внимания.
Оценка
Обучение
Мы используем класс PyG NeighborSampler в качестве нашего загрузчика данных. NeighborSampler позволяет нам случайным образом выбирать различное количество соседей на разных уровнях модели. Чтобы оценить производительность модели, мы используем коэффициент детерминации (R²) между нашими прогнозируемыми коэффициентами отмены и истинными значениями, указанными в следующей формуле:
Наш конвейер обучения можно увидеть ниже, где мы используем оптимизатор Адама (определенный в модели), среднеквадратичную потерю ошибок, снижение веса, отсев и раннюю остановку (путем записи лучшей модели на основе проверки r²).
train_loader = NeighborSampler(data.edge_index, node_idx=data.train_mask, sizes=[10, 5], batch_size=64, shuffle=True, num_workers=1) def train_gsns(model, data, nepochs=1000): """Train a GraphSAGE model with Neighbor Sampling and return the trained final model, best model, and recorded metrics.""" metrics = {"epoch": [], "train_r2": [], "train_loss": [], "val_r2": [], "val_loss": []} criterion = torch.nn.MSELoss() optimizer = model.optimizer epochs = nepochs best_val_r2 = -10 best_model = None model.train() for epoch in range(epochs+1): ys = [] preds = [] total_loss = total_examples = 0.0 # Training for batch_size, n_id, adjs in train_loader: adjs = [adj.to(device) for adj in adjs] optimizer.zero_grad() out, log_out = model(data.x[n_id], adjs) loss = criterion(out.flatten(), data.y[n_id[:batch_size]].flatten()) total_loss += loss total_examples += out.numel() ys.append(data.y[n_id[:batch_size]].flatten()) preds.append(out.flatten()) loss.backward() optimizer.step() # Compute Training Metrics train_loss = total_loss / total_examples preds = torch.cat(preds, dim=0) ys = torch.cat(ys, dim=0) train_r2, _, _ = calculate_metric(preds, ys) # Compute Validation Metrics test_out, test_logout = model(data.x, data.edge_index, train=False) val_loss = criterion(test_out[data.val_mask].flatten(), data.y[data.val_mask]) val_r2, _, _ = calculate_metric(test_out[data.val_mask], data.y[data.val_mask]) # Print metrics every 100 epochs if(epoch % 100 == 0): print(f'Epoch {epoch:>3} | Train Loss: {train_loss:.3f} | Train r^2: ' f'{train_r2:>6.2f} | Val Loss: {val_loss:.2f} | ' f'Val r^2: {val_r2:.2f}') # Save metrics for plotting metrics["epoch"].append(epoch) metrics["train_r2"].append(train_r2) metrics["train_loss"].append(train_loss) metrics["val_r2"].append(val_r2) metrics["val_loss"].append(val_loss) # Save best model if epoch > int(epochs / 2) and val_r2 > best_val_r2: #epoch > int(epochs / 2) and best_val_r2 = val_r2 best_model = copy.deepcopy(model) return model, best_model, metrics gsagens = GraphSAGENeighborSampling(data.num_features, 64, 1, dropout=0.1, weight_decay=5e-3, lr=1e-3).to(device) gsagens, best_gsagens, metrics_gsagens = train_gsns(gsagens, data, nepochs=1000)
Полученные результаты
Мы провели несколько экспериментов с использованием разных гиперпараметров, и наши основные выводы можно увидеть в таблице ниже:
Для модели GCN мы видим, что, несмотря на увеличивающийся показатель r² обучения, она не может обобщить невидимые узлы в наборе проверки и тестирования. Мы также видим, что предварительно созданные слои GraphSAGE от Pytorch-Geometric работают лучше, чем наша собственная реализация, благодаря лучшим схемам оптимизации. Наши лучшие результаты при использовании 2-Layer GraphSAGE с отсевом имеют r² 0,85, что указывает на сильную связь между прогнозом модели и правильными метками. Таким образом, мы эффективно продемонстрировали, как модель GraphSAGE может предсказывать частоту отмены рейсов в аэропортах.
Мы также видим, что модель сразу же сглаживается при добавлении третьего слоя GraphSAGE. Это даже происходит, когда мы применяем выборку соседей. Это подтверждается просмотром вложений для всех узлов, которые очень похожи друг на друга. С нашей структурой графа почти все узлы находятся на расстоянии 2 переходов друг от друга из-за крупных узловых аэропортов, принимающих рейсы в большинство аэропортов и из них. Таким образом, двухуровневая модель работает хорошо, но трехслойная модель предсказывает, что все узлы будут одинаковыми.
Мы обнаружили, что двухслойная модель GraphSAGE с выборкой соседей и регуляризацией (уменьшение веса и отсев) работает лучше всего. Вероятно, это связано с тем, что предыдущие модели легко подгонялись к нашим обучающим примерам, а также не позволяли узлам с большими перекрывающимися структурами соседства (например, два аэропорта на Гавайях) разрабатывать одно и то же вложение.
Вот наши обучающие графики для этой модели:
Визуализация
После обучения модели мы можем сравнить вложения 2D TSNE [7] наших векторов входных признаков с нашими обученными вложениями GraphSAGE. Мы будем использовать sklearn-многообразие.TSNE для этого:
# Plot TSNE for Trained Embeddings # # Find TSNE Embeddings tsne = TSNE(2, perplexity=15, verbose=3, n_iter=2000, n_iter_without_progress=300) tsne_proj = tsne.fit_transform(gsagens_embeddings) tx, ty = tsne_proj[:,0], tsne_proj[:,1] # Normalize labels to [0.2, 1.0] for visualization max_y = data.y.max().detach().cpu().numpy() min_y = data.y.min().detach().cpu().numpy() data_y = data.y.detach().cpu().numpy() t_max = 1.0 t_min = 0.2 min_max_y = (data_y - min_y) / (max_y - min_y) * (t_max - t_min) + t_min min_max_y_train = min_max_y * data.train_mask.detach().cpu().numpy() min_max_y_val = min_max_y * data.val_mask.detach().cpu().numpy() min_max_y_test = min_max_y * data.test_mask.detach().cpu().numpy() color_map = mpl.colormaps['Reds'] node_color = [color_map(label) for label in min_max_y] node_aiport = [id_to_airport[airportid] for airportid in range(len(min_max_y))] # Plot TSNE with label as color fig, ax = plt.subplots(figsize=(8,8)) ax.scatter(tx, ty, c=node_color, alpha=0.5) for i, airport in enumerate(node_aiport): if i in label_nodes: ax.annotate(airport, (tx[i], ty[i])) plt.show()
Вложения TSNE окрашены в соответствии с их меткой более темными цветами, что указывает на большой процент отмененных рейсов. Мы также пометили несколько аэропортов их кодом FAA для ясности и сравнения двух графиков TSNE.
Во встраивании TSNE обученных вложений GraphSAGE мы видим, что существует четкий градиент цветов, сигнализирующий о том, что модель изучает четкие взаимосвязи между свойствами узла и их метками (эквивалентно, если мы определяем бины/пороги для меток, чтобы превратиться это в проблему классификации, мы увидим, что кластеры формируются на основе меток классов). С другой стороны, встраивание TSNE для векторов входных признаков показывает менее четкие закономерности.
Оптимизация маршрута полета
Чтобы показать пример полезности наших прогнозов, мы демонстрируем алгоритм, который дает маршрут между двумя разными аэропортами с минимальной вероятностью отмены. Алгоритм предполагает независимость между отменами в каждом аэропорту. Во-первых, мы строим ориентированный граф всех аэропортов и рейсов между ними в виде узлов и ребер соответственно. Наша цель — найти минимальный путь между нашей исходной вершиной и точкой назначения с весами ребер, представляющими вероятность того, что наш рейс будет отменен. Обычно эту проблему решают с помощью стандартного алгоритма поиска кратчайшего пути, такого как алгоритм Дейкстры. Обычно это работает, но нам нужно внести изменения, так как мы не складываем веса ребер вместе в заданном пути, а умножаем их, потому что имеем дело с вероятностями. Поскольку логарифм является монотонной функцией и
мы можем заменить каждый вес a на log(a) в алгоритме Дейкстры, чтобы найти оптимальный путь. Вот пример игрушки:
Предположим, мы начинаем в Портленде (PDX) и должны добраться до Чикаго (ORD), а прямых рейсов в этот день нет. Наши варианты: сначала лететь в Солт-Лейк-Сити (SLC) или Сан-Франциско (SFO), а затем из любого из них в Чикаго. Предполагая равную вероятность отмены входящих и исходящих рейсов, предположим, что модель предсказывает, что вероятность отмены в PDX составляет 3 %, в SFO — 4 %, в SLC — 4,3 %, а в ORD — 1 %. Тогда путь через SLC будет log(0,03 · 0,043) + log(0,043 · 0,01) = -6,256, а путь через SFO будет log(0,03 · 0,04) + log(0,04 · 0,01) = -6,319. Путь SFO имеет более низкий балл и, следовательно, более низкую вероятность отмены, поэтому мы должны выбрать этот путь. Этот пример может показаться довольно очевидным с таким коротким путем и таким небольшим количеством вариантов перелета, но преимущества становятся более очевидными, если мы делаем, скажем, всемирную логистику доставки грузов с гораздо большим количеством вариантов перелета и несколькими возможными начальными и конечными пунктами.
Заключение
Наши результаты показывают, что графовые нейронные сети могут быть успешно применены к данным об аэропортах и рейсах. Мы эффективно создали модель GraphSAGE, которая способна прогнозировать процент отмен рейсов в заданном аэропорту и в заданное время. Благодаря этому мы смогли показать важность наших результатов в реальных сценариях логистики. Мы надеемся, что это послужит ценным ресурсом для обучения и визуализации результатов графовых нейронных сетей.
В будущем мы хотели бы расширить набор данных, чтобы охватить все аэропорты США, а также международные аэропорты. Это позволило бы нам экспериментировать с трансдуктивными моделями, создавать более глубокие сети и работать с действительно случайными разбиениями. Мы также хотели бы посмотреть, работает ли наш подход с другими ярлыками интересов, такими как задержки. Что касается структуры графа, мы хотели бы изучить моделирование этой проблемы как неоднородного графа, где рейсы и аэропорты являются узлами, чтобы мы могли использовать прогнозирование ссылок для прогнозирования того, откуда или куда направляются рейсы. Наконец, мы хотели бы изучить влияние внесения архитектурных изменений, таких как изменение функции агрегирования.
Рекомендации
[1] Источник данных OpenFlights. Получено с https://openflights.org/data.html.
[2] Фей, М., и Ленссен, Дж. Э. (2019). Быстрое изучение представления графа с помощью PyTorch Geometric. Препринт arXiv arXiv: 1903.02428.
[3] Уильям Л. Гамильтон, Рекс Ин и Юре Лесковец. 2017. Обучение индуктивному представлению на больших графах. В материалах 31-й Международной конференции по системам обработки нейронной информации (NIPS’17). Curran Associates Inc., Ред-Хук, Нью-Йорк, США, 10:25–10:35.
[4] Нитиш Сривастава, Джеффри Хинтон, Алекс Крижевский, Илья Суцкевер и Руслан Салахутдинов. 2014. Dropout: простой способ предотвратить переоснащение нейронных сетей. Дж. Мах. Учиться. Рез. 15, 1 (январь 2014 г.), 1929–1958 гг.
[5] Zhang, H., Lu, G., Zhan, M. et al. Полууправляемая классификация сверточных сетей графов с ранговыми ограничениями Лапласа. Neural Process Lett 54, 2645–2656 (2022). https://doi.org/10.1007/s11063-020-10404-7
[6] Насколько внимательны сети графического внимания?, Шакед Броуди, Ури Алон, Эран Яхав, arXiv: 2105.14491, 2021 г.
[7] Ваттенберг и др., Как эффективно использовать t-SNE, Distill, 2016. https://doi.org/10.23915/distill.00002