Gorio Tech Blog search

Pytorch Geometric Message Passing 설명

|

목차

본 글에서는 Pytorch Geometric에서 가장 기본이 되는 MessagePassing class에 대해 설명하고, 활용 방법에 대해서도 정리하도록 할 것이다.

GNN의 여러 대표 알고리즘이나 torch_geometric.nn의 대표 layer에 대한 간략한 설명을 참고하고 싶다면 Github을 참고해도 좋다.


Message Passing 설명

1. Background

GNN은 대체적으로 Neighborhood Aggregation & Combine 구조의 결합으로 구성되는데, 이를 또 다른 말로 표현하면 Message Passing이라고 할 수 있다.

특정 node를 설명하기 위한 재료로 그 node의 neighborhood의 특징을 모으는 과정이 바로 Message Passing인 것이다.

\[x^{l+1}_i = \gamma^l (x_i^l, AGG_{j \in \mathcal{N}_i} \phi^l (x_j^l, ...) )\]

디테일의 차이는 있지만 GNN의 많은 대표 알고리즘들은 각자의 Message Passing 논리가 있고, Pytorch Geometric에서는 이러한 scheme을 효과적으로 구현하기 위해 MessagePassing이라는 class를 제공하고 있다.

Source 코드는 이 곳에서 확인할 수 있다.

MessagePassing은 torch.nn.Module을 상속받았기 때문에 이 class를 상속할 경우 다양한 Graph Convolutional Layer를 직접 구현할 수 있게 된다. (물론 굉장히 많은 Layer가 이미 기본적으로 제공되고 있다.)


2. MessagePassing 구조

MessagePassing을 생성할 때는 아래와 같은 사항을 정의해주어야 한다.

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)

위 예시는 기본 값을 나타낸 것으로, 설계에 따라 aggregation 방법을 mean/max 등으로 바꾸거나 할 수 있다.

source code를 보면 여러 메서드가 정의되어 있는 것을 알 수 있는데, 가장 중요한 메서드는 propagate이다. 이 메서드는 내부적으로 message, aggregate, update를 자동적으로 call한다. 앞서 보았던 식으로 설명하면 이해가 좀 더 편한데,

\[x^{l+1}_i = \gamma^l (x_i, AGG_{j \in \mathcal{N}_i} \phi^l (x_j^l, ...) )\]

위 식에서 먼저 $\phi$ 에 해당하는 부분이 message이다. 이 영역은 대체적으로 미분 가능한 MLP로 구성하는데, 이웃 node x_j 의 정보를 어떻게 가공하여 target node x_i 에 전달할지를 정의하는 부분이다. 참고로 i, j notation은 Pytorch Geometric 전체에서 명확히 구분하고 있으니 임의로 바꾸는 것을 추천하지는 않는다.

식에서 $AGG$ 라고 되어 있는 부분이 당연히 aggregate를 의미한다. 이웃 node의 특성을 모으는 과정이다. 여러 방법이 있으나 일단은 간단하게 sum을 생각해보자.

$\gamma$ 함수가 update 부분을 담당하게 된다. 이전 layer의 값이 현재 layer로 업데이트 되는 것이다.


3. Code 설명

MessagePassing 코드를 보면 상단에 아래와 같은 부분이 있다.

class MessagePassing(torch.nn.Module):
    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):

        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        # 함수를 검사하여 인자를 OrderedDict 형태로 취함
        # pop_first=True 이면 첫 인자는 버림
        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)   # message 메서드
        ...

이후에도 확인하겠지만, 이 class를 구현할 때 여러 메서드들 작성해야 하는데, 대부분 additional argument를 허용하는 구조로 되어 있다. 그래서 MessagePassing class에서는 이러한 인자들을 inspector를 통해 제어한다.

import re
import inspect
from collections import OrderedDict
from typing import Dict, List, Any, Optional, Callable, Set

class Inspector(object):
    def __init__(self, base_class: Any):
        self.base_class: Any = base_class
        self.params: Dict[str, Dict[str, Any]] = {}

    def inspect(self, func: Callable,
                pop_first: bool = False) -> Dict[str, Any]:
        params = inspect.signature(func).parameters
        params = OrderedDict(params)
        if pop_first:
            params.popitem(last=False)
        self.params[func.__name__] = params

    def keys(self, func_names: Optional[List[str]] = None) -> Set[str]:
        keys = []
        for func in func_names or list(self.params.keys()):
            keys += self.params[func].keys()
        return set(keys)

    def __implements__(self, cls, func_name: str) -> bool:
        if cls.__name__ == 'MessagePassing':
            return False
        if func_name in cls.__dict__.keys():
            return True
        return any(self.__implements__(c, func_name) for c in cls.__bases__)

    def implements(self, func_name: str) -> bool:
        return self.__implements__(self.base_class.__class__, func_name)

    def types(self, func_names: Optional[List[str]] = None) -> Dict[str, str]:
        out: Dict[str, str] = {}
        for func_name in func_names or list(self.params.keys()):
            func = getattr(self.base_class, func_name)
            arg_types = parse_types(func)[0][0]
            for key in self.params[func_name].keys():
                if key in out and out[key] != arg_types[key]:
                    raise ValueError(
                        (f'Found inconsistent types for argument {key}. '
                         f'Expected type {out[key]} but found type '
                         f'{arg_types[key]}.'))
                out[key] = arg_types[key]
        return out

    def distribute(self, func_name, kwargs: Dict[str, Any]):
        # func_name = 예) 'message'
        # kwargs = coll_dict
        # inspector.params['message']에 있는 argument들을 불러온 뒤
        # 이들에게 해당하는 데이터를 coll_dict에서 가져옴
        out = {}
        for key, param in self.params[func_name].items():
            data = kwargs.get(key, inspect.Parameter.empty)
            if data is inspect.Parameter.empty:
                if param.default is inspect.Parameter.empty:
                    raise TypeError(f'Required parameter {key} is empty.')
                data = param.default
            out[key] = data
        return out

아래 코드는 어떤 함수가 갖고 있는 argument들을 불러오는 것을 의미한다.

params = inspect.signature(func).parameters

예를 들어 아래 코드를 실행하면,

import inspect
from collections import OrderedDict

def func(a='OH', b=7, *args, **kwargs):
    pass

params = inspect.signature(func).parameters
params = OrderedDict(params)

다음과 같은 결과를 확인할 수 있다.

OrderedDict([('a', <Parameter "a='OH'">), ('b', <Parameter "b=7">), ('args', <Parameter "*args">), ('kwargs', <Parameter "**kwargs">)])

위 사항을 인지하고 다시 코드를 보면,

class MessagePassing(torch.nn.Module):
    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):

        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        # 함수를 검사하여 인자를 OrderedDict 형태로 취함
        # pop_first=True 이면 첫 인자는 버림
        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.message_and_aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)

        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)
        self.__fused_user_args__ = self.inspector.keys(
            ['message_and_aggregate', 'update']).difference(self.special_args)

        # Support for "fused" message passing.
        # message_and_aggregate 메서드를 구현하면 self.fuse = True
        # self.inspector.base_class.__dict__.keys()에서 확인 가능
        self.fuse = self.inspector.implements('message_and_aggregate')

위 과정이 여러 메서드의 인자들을 수집한 후 이를 OrderedDict에 저장하는 과정임을 알 수 있다.

코드를 밑바닥에서 보면 파악하는 속도가 느리기 때문에 실제 데이터를 바탕으로 이어서 확인해보도록 하겠다.

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch_geometric.transforms as T

# Load Cora Dataset
dataset = 'Cora'
path = os.path.join(os.getcwd(), 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

x, edge_index = data.x, data.edge_index

print(x.shape, edge_index.shape)
# (torch.Size([2708, 1433]), torch.Size([2, 10556]))

2708개의 node가 존재하고, 이 node들은 10556개의 edge를 통해 graph를 구성하고 있음을 알 수 있다. 초반에 확인한 flow=”source_to_target”는 edge_index의 첫 행은 source node, 두 번째 행은 target node로 구성되어 있다는 것을 의미한다.

이제 공식 문서의 예제에서 처럼 GCN Layer를 한 번 정의해보자.

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

conv = GCNConv(x.shape[1], 32)

몇 가지 사항에 대해 확인해 보자.

# conv.inspector.params
# = message_passing이 갖고 있는 method의 인자들을 일부(pop_item=True)를 제외하고 취한 것
# Ex) {'message': OrderedDict([('x_j', <Parameger "x_j">)])}
for param, item in conv.inspector.params.items():
    print(param, ': ', item)

# check
print(conv.__user_args__)
print(conv.__fused_user_args__)
print(conv.fuse)

"""
message :  OrderedDict([('x_j', <Parameter "x_j">), ('norm', <Parameter "norm">)])
aggregate :  OrderedDict([('index', <Parameter "index: torch.Tensor">), ('ptr', <Parameter "ptr: Optional[torch.Tensor] = None">), ('dim_size', <Parameter "dim_size: Optional[int] = None">)])
message_and_aggregate :  OrderedDict()
update :  OrderedDict()

# check
{'norm', 'x_j'}
set()
False
"""

step 3까지 모두 진행했다고 해보자. 그러면 propagate 메서드에서 수행할 작업은 edge_index, 즉 주어진 graph 구조에 맞게 x의 feature들을 통합하는 과정이 될 것이다.

propagate 메서드 상단 부분을 보자. 복잡함을 피하기 위해 fused version 부분은 제거하였다. (실제로 사용할 때는 message_and_aggregate를 구현하는 것이 좋을 때가 많다. 왜냐하면 불필요한 연산을 줄임으로써 속도를 개선하고 메모리 사용량을 줄일 수 있기 때문이다. 아예 구현하지 않으면 호출될 일이 없다.)

def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
    size = self.__check_input__(edge_index, size)

    # Run "fused" message and aggregation (if applicable).
    # (생략)

    # Otherwise, run both functions in separation.
    elif isinstance(edge_index, Tensor) or not self.fuse:
        coll_dict = self.__collect__(
            self.__user_args__, edge_index, size, kwargs)

위 메서드에서 kwargs 부분은 중요하다. 왜냐하면 MessagePassing 클래스는 내부적으로 message, aggregate, update에서 추가한 argument들(위 예시에서는 norm 같은 경우)을 propagate 메서드에서 사용할 수 있게 해두었기 때문이다. 그렇기 때문에 위 GCN Layer에서도 아래와 같이 추가적인 argument를 전달할 수 있는 것이다.

# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)

coll_dict는 이 프로세스를 통과하고 있는 주요 변수/데이터를 딕셔너리 형태로 저장한 것이다. 위 예시에서 coll_dict는 아래와 같은 형상을 하고 있다.

{'norm': tensor([0.2500, 0.2236, 0.2500,  ..., 0.5000, 0.2000, 0.2000]),
 'x_j': tensor([[-0.0106, -0.0185, -0.0095,  ...,  0.0051, -0.0180,  0.0261],
                 ...,
                [-0.0148, -0.0149, -0.0153,  ..., -0.0033, -0.0236,  0.0217]],
                 grad_fn=<IndexSelectBackward>),
 'adj_t': None,
 'edge_index': tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
                       [ 633, 1862, 2582,  ..., 2705, 2706, 2707]]),
 'edge_index_i': tensor([ 633, 1862, 2582,  ..., 2705, 2706, 2707]),
 'edge_index_j': tensor([   0,    0,    0,  ..., 2705, 2706, 2707]),
 'ptr': None,
 'index': tensor([ 633, 1862, 2582,  ..., 2705, 2706, 2707]),
 'size': [2708, None],
 'size_i': 2708,
 'size_j': 2708,
 'dim_size': 2708}

나머지 과정을 확인해 보자. 복잡한 설명을 피하기 위해 중간 부분은 생략하였다.

msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)

# For `GNNExplainer`, we require a separate message and aggregate
# (생략)

aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)

update_kwargs = self.inspector.distribute('update', coll_dict)
output = self.update(out, **update_kwargs)

위 과정은 message, aggregate, update을 차례로 실행하는 과정이다. 이 때 inspector.distribute 부분은 coll_dict에서 필요한 데이터를 불러오는 과정을 의미하는데, 내부적으로 실행 과정을 보면 아래와 같다.

1) func_name = ‘message`
2) arguments = inspector.params[func_name]
3) arguements에 해당하는 데이터를 coll_dict에서 가져옴

이를 통해 생성된 결과물은 아래 예시에서 확인할 수 있다.

print({a:b.shape for a, b in msg_kwargs.items()})
# {'x_j': torch.Size([13264, 32]), 'norm': torch.Size([13264])}

aggregate 메서드를 적용한 후의 결과물의 shape은 (2708, 32)가 되는데, 이는 (target x_i 수, out_channels)와 일치한다. 즉 이 결과물은 node x_i를 기준으로 aggregated 된 feature matrix인 것이다. update 메서드를 적용해서 최종 아웃풋을 얻을 수 있는데, 위 예시에서는 update 메서드를 수정하지 않았으므로 이전 단계의 결과물이 그대로 전달된다.

지금까지가 MessagePassing class에 대한 간략한 설명이었고, 추후에는 이를 응용하여 Custom Graph Convolutional Layer를 만드는 방법에 대해 포스팅하도록 할 계획이다.