Shang, Chao, Jie Chen, and Jinbo Bi. "Discrete Graph Structure Learning for Forecasting Multiple Time Series." arXiv preprint arXiv:2101.06861 (2021).
- Paper: Discrete Graph Structure Learning for Forecasting Multiple Time Series
- Code: https://github.com/chaoshangcs/GTS
References on the Gumbel distribution:
- https://data-newbie.tistory.com/263
- https://blog.evjang.com/2016/11/tutorial-categorical-variational.html
This is an ICLR 2021 paper.
As described in the paper, the graph structure is sampled using the Gumbel distribution.
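Concretely, the edge sampling can be written with the binary Gumbel reparameterization, roughly as follows (my transcription of the paper's formula; g¹ᵢⱼ and g²ᵢⱼ are i.i.d. Gumbel(0,1) draws, θᵢⱼ is the learned link probability, and s is the temperature):

A_{ij} = \operatorname{sigmoid}\!\left( \frac{\log\big(\theta_{ij}/(1-\theta_{ij})\big) + g^{1}_{ij} - g^{2}_{ij}}{s} \right)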
First, the forward pass of GTSModel (https://github.com/chaoshangcs/GTS/blob/main/model/pytorch/model.py) relies on the following code.
The softmax seems to be there so that the sampled values form a valid probability distribution summing to 1 over the last dimension — here, the two class probabilities (edge / no edge) for each node pair.
The code is written in PyTorch.
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # defined at module level in the repo

def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel(0, 1) noise via inverse transform sampling: -log(-log(U))."""
    U = torch.rand(shape).to(device)
    return -torch.autograd.Variable(torch.log(-torch.log(U + eps) + eps))

def gumbel_softmax_sample(logits, temperature, eps=1e-10):
    """Add Gumbel noise to the logits and apply a temperature-scaled softmax."""
    sample = sample_gumbel(logits.size(), eps=eps)
    y = logits + sample
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False, eps=1e-10):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, the returned sample is one-hot; otherwise it is a
        probability distribution that sums to 1 across classes.
    """
    y_soft = gumbel_softmax_sample(logits, temperature=temperature, eps=eps)
    if hard:
        shape = logits.size()
        _, k = y_soft.data.max(-1)
        y_hard = torch.zeros(*shape).to(device)
        y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
        # Straight-through estimator: the forward value is the one-hot y_hard,
        # but the gradient flows through the soft sample y_soft.
        y = torch.autograd.Variable(y_hard - y_soft.data) + y_soft
    else:
        y = y_soft
    return y
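For context, this is roughly how the forward pass turns the Gumbel-Softmax sample into an adjacency matrix (a simplified sketch; the variable names are mine and the repo's forward has additional details):

num_nodes = 4                                                  # toy size for illustration
logits = torch.randn(num_nodes * num_nodes, 2, device=device)  # per-pair scores: class 0 = edge, class 1 = no edge
sample = gumbel_softmax(logits, temperature=0.5, hard=True)
adj = sample[:, 0].reshape(num_nodes, num_nodes)               # keep the "edge" class as a 0/1 adjacency

With hard=True each row of sample is one-hot, so adj comes out binary while gradients still reach logits through the straight-through trick.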
The paper says: "When the temperature s → 0, Aij = 1 with probability θij and 0 with remaining probability. In practice, we anneal s progressively in training such that it tends to zero." In other words, s is gradually driven toward 0 over the course of training.
However, in the released code, the part that gradually adjusts the temperature appears to have been commented out.
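For reference, a minimal sketch of what such annealing could look like (the function name, schedule, and constants below are my own illustration, not code from the GTS repo):

def anneal_temperature(epoch, init_temp=0.5, decay=0.95, min_temp=0.05):
    # Exponentially decay the Gumbel-Softmax temperature toward zero,
    # clamped at min_temp so the softmax stays numerically stable.
    return max(min_temp, init_temp * decay ** epoch)

The annealed value would then be passed as the temperature argument of gumbel_softmax at each training epoch.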
------
TF Gumbel Sample Code
import tensorflow as tf  # TF1-style API: tf.random_uniform and keep_dims are pre-TF2 names

def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise via inverse transform sampling
    U = tf.random_uniform(shape, minval=0, maxval=1)
    return -tf.log(-tf.log(U + eps) + eps)

def gumbel_softmax(logits, temperature, hard=False):
    gumbel_softmax_sample = logits + sample_gumbel(tf.shape(logits))
    y = tf.nn.softmax(gumbel_softmax_sample / temperature)
    if hard:
        k = tf.shape(logits)[-1]  # number of classes (unused below)
        # one-hot indicator at the argmax of y
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
        # straight-through: forward value is y_hard, gradient flows through y
        y = tf.stop_gradient(y_hard - y) + y
    return y
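Note that tf.stop_gradient(y_hard - y) + y plays exactly the same role as torch.autograd.Variable(y_hard - y_soft.data) + y_soft in the PyTorch version above: the forward pass outputs the one-hot y_hard, while the gradient is computed with respect to the soft sample y.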