본문 바로가기
MLOps

[MLOps, Acceleration, Jax] Google Jax 와 Flax Library

by ML_MJSHIN 2021. 12. 15.

  이번 포스팅 부터는 실제로 JAX를 사용한 ML 모델 구현 코드를 보면서 공부를 하고 이해하는 포스팅이 될 것 같습니다. JAX만을 가지고도 모델을 구성할 수 있지만 실제로 많은 라이브러리들이 존재하고 대표적으로 Haiku와 Flax 두 가지가 ML library로써 모델을 구축하고 학습하는 것에 큰 편리성을 제공해주는 라이브러리입니다.

 

 두 라이브러리의 차이가 약간 존재합니다. Flax는 optimizer와 mixed precision, training loop에 관한 함수들을 자체적으로 가지고 있어서 원하는 것을 가져다가 사용할 수 있습니다. 하지만 Haiku는 조금 더 모델 구성 자체에 대한 함수에 집중하여 optimizer 와 training loop 와 관련된 모델 외적인 부분들은 다른 라이브러리의 사용을 추천하고 이들에 의존하고 있습니다. 

 

 DeepMind는 Haiku를 사용하는 것으로 알려져 있고, Flax는 구글에서 만들어서 사용하는 것으로 알려져 있습니다. Github의 Star 숫자는 Flax가 조금 더 높네요.

 

  Flax는 설치법도 매우 간단하며 다음 링크에서 자세한 내용들을 찾아보실 수 있습니다. https://github.com/google/flax

 

GitHub - google/flax: Flax is a neural network library for JAX that is designed for flexibility.

Flax is a neural network library for JAX that is designed for flexibility. - GitHub - google/flax: Flax is a neural network library for JAX that is designed for flexibility.

github.com

 

 이번 포스팅부터는 저는 Flax를 기반으로 작업을 했었기 때문에 Flax 기반의 transformer 코드를 가져와서 공부하는 포스팅을 진행하게 될 것 같습니다. 

 

 Flax를 통해서 ML 모델을 구현하고자 할 때 기존의 Pytorch와는 약간씩 다른 부분들이 있어서 해당 부분들을 하나하나 짚어가면서 살펴보도록 하겠습니다.

 

 우선 Transformer 모델은 'Attention is all you need' 논문에서 소개되었던 가장 기본적은 Vanila Transformer 모델을 기반으로 하도록 하겠습니다. 

from flax import nn

class MaskedMultiHeadSelfAttention(nn.Module):
  n_heads : int = 8
  d_k : int = 64
  d_v : int = 64
  output_dim : int = 512 
  
  # Example of setup
  # def setup(self):
  #    self.param_dict = {} 
  #    self.decoder = nn.Dense( ~~~ )
  #    self.n_heads = 8

  @nn.compact  
  def __call__ (self, input: jnp.ndarray, mask : jnp.ndarray):
    heads = []

    for i in range(self.n_heads):
      # projection
      keys_i = nn.Dense(self.d_k, name=f"Head_{i}_keys")(input)
      query_i = nn.Dense(self.d_k, name=f"Head_{i}_query")(input)
      values_i = nn.Dense(self.d_v, name=f"Head_{i}_value")(input)
      
      # self-attention
      x = self.MaskedScaledDotProductAttention(keys_i, query_i, values_i, mask)
      heads.append(x)
      
    heads = jnp.concatenate(heads, axis=-1)
    output = nn.Dense(self.output_dim, name="output_layer")(heads)    
    return output

  def MaskedScaledDotProductAttention(self, keys : jnp.ndarray, query : jnp.ndarray, values : jnp.ndarray, mask : jnp.ndarray):
    x = jnp.einsum("...ik, ...jk -> ...ij", query, keys) / jnp.sqrt(self.d_k) # Q x K.T 
    x = nn.softmax(x)
    x = jnp.einsum("...jk, j -> ...jk", x, mask) # k 는 단어이므로 mask 가 0이면 단어 모두 0 처리
    x = jnp.matmul(x, values)
    return x

 아래는 Flax를 기반으로 구현한 Masked-Multi-Head-Attention 입니다. 다음의 코드에서 주의해서 살펴볼 부분은 다음과 같습니다.

  1. __init__ 함수가 없다
  2. __call__ 함수에서 Pytorch의 forward함수에 들어갈 내용을 구성한다

 사실 위의 두 가지가 Flax로 모델을 구현할 때 고려할 유일한 특징입니다. 우선 __init__함수가 있어야 할 곳에 data와 관련된 내용을 정리하고 있습니다. 그리고 self 와 같은 표현도 사용하고 있지 않습니다. 하지만 아래쪽을 보시면 이 부분에서 정의한 변수들을 self 를 붙여서 접근하고있는 것을 보실 수 있습니다.

 

 이렇게 작동이 가능한 이유는 Flax에서는 class 이름을 정의하고 바로 처음에 오는 영역에 대해서 Data Class 영역으로 고정하고 있습니다. 그리고 이 부분에서 정의되는 데이터는 내부적으로 __init__함수를 기반으로 class의 변수로 만들어 주고 있기 때문에 __init__함수가 없다고 생각하시면 될 것 같습니다.

 

 그러면 __init__ 함수에서 정의하던 여러 Dense layer 나 다른 곳에서 정의한 ML 모델을 변수에 저장할 수 없는 것인가 라고 생각하신다면 이를 위해서 setup 함수가 존재합니다.  아래의 코드에서는 예제로 달아두었습니다. setup에서 Data Class 파트의 변수를 정의해주면서 Data Class에서 아무런 변수를 선언해주지 않아도 되며, 원하는 Flax 레이어들에 대해서 정의하고 __call__ 부분에서 가져다가 사용만해도 상관 없습니다. 

 

 그런데 @nn.compact 라는 부분이 있습니다. 이것도 Flax를 통한 구현에서 중요한 부분입니다. 보통 setup 함수를 쓰는 경우 layer들을 미리 정의하기 위해서 사용하게 됩니다. 하지만 __call__함수 내에서 input -> output 의 data control flow 를 정의하면서 layer를 바로바로 정의 할 수도 있으며 이를 위해서 @nn.compact decorator를 사용하게 됩니다. 

 

 간단히 말해, Pytorch 처럼 __init__ 내에서 layer를 사전에 정의하고 싶다면 @nn.compact decorator를 사용하지 않고 정의 해주고 사용하는 디자인 패턴으로 구현 하시면 되며, 만약 사전에 정의하기 싫고 data flow를 그리면서 바로바로 layer를 정의해주고 싶다면 @nn.compact decorator를 활용하시면 됩니다. 

 

  이제 __call__함수의 내부를 살펴보도록 하겠습니다. call 함수의 내부에서 처음 유의하실 부분은 layer 선언부 입니다. Flax (Haiku도 마찬가지) 에서는 layer 선언시에 input_dimension에 대한 정보를 자동으로 추론해주기 때문에 모든 레이어 선언시에 파라미터로 입력해 주지 않습니다. self.d_k는 64로 nn.Dense 의 output dimension 입니다. 

 

  # 여기서 einsum 에 대해서 스터디 하는 친구와 얘기하다가 (...ik, ...jk -> ...ij)가 아니냐는 얘기가 나왔었는데 einsum 사용시에 (nik, njk -> nij) 라고 해주면 torch.bmm(A, B.permute(0, 2, 1)) 과 같은 효과로 B 행렬의 (위에선 keys)를 transpose 해주고 matrix multiplication을 수행한 효과를 알아서 내줍니다.

 

 다음으로는 위의 self-attention을 이용해서 encoder와 decoder를 구성하는 코드를 살펴보도록 하겠습니다. (Encoder와 Decoder를 합쳐서 작동하는 코드는 이 부분을 보고 볼 예정입니다. 기다려주세요 ㅎㅎ) 

class PositionwiseFeedForward(nn.Module):
  hidden_dim : int = 2048
  output_dim : int = 512

  @nn.compact
  def __call__(self, input):
    x = input
    x = nn.Dense(self.hidden_dim)(x)
    x = nn.relu(x)
    x = nn.Dense(self.output_dim)(x)
    return x
    
class EncoderBlock(nn.Module):
  emb_dim : int = 512
  n_heads : int = 8
  d_k     : int = 64
  d_v     : int = 64
  d_ff    : int = 2048
  
  @nn.compact
  def __call__(self, input):
    x = input
    x = MultiHeadSelfAttention(self.n_heads, self.d_k, self.d_v, self.emb_dim)(x)
    x = x + input
    x = nn.LayerNorm()(x)

    y = x.copy()
    y = PositionwiseFeedForward(self.d_ff, self.emb_dim)(y)
    y = y + x
    y = nn.LayerNorm()(y)

    return y

class DecoderBlock(nn.Module):
  emb_dim : int = 512
  d_k : int = 64
  d_v : int = 64
  d_ff : int = 2048
  n_heads : int = 8  

  @nn.compact
  def __call__(self, input : jnp.ndarray, encoding : jnp.ndarray,  mask : jnp.ndarray):
    x = input
    x = MaskedMultiHeadSelfAttention(n_heads=self.n_heads, d_k=self.d_k, d_v=self.d_v, output_dim=self.emb_dim)(x, mask)
    x = x + input
    x = nn.LayerNorm()(x)

    y = x.copy()
    y = MultiHeadAttention(n_heads=self.n_heads, d_k=self.d_k, d_v=self.d_v, output_dim=self.emb_dim)(encoding, y, encoding)
    y = y + x
    y = nn.LayerNorm()(y)

    z = y.copy()
    z = PositionwiseFeedForward(hidden_dim=self.d_ff, output_dim=self.emb_dim)(z)
    z = z + y
    z = nn.LayerNorm()(z)

    return z

  먼저 PositioniseFeedForward 부터 살펴보도록 하겠습니다. 위에서 본 것 처럼 Flax.linen 라이브러리의 Module을 상속받았기 때문에, Data Classes에 원하는 변수들을 선언해줍니다. 그리고 'Attention is all you need' 논문의 FF layer 처럼 Dense layer를 2개 붙여줍니다. 이때 위에서 말씀드린 것처럼 nn.Dense 에서는 output_dimension 만을 전달해주면 되므로 매우 간단하게 구현이 끝났습니다.

 

 다음으로는, Encoder 와 Decoder 입니다. 사실 두 부분은 Attention 부분에서 Mask가 있냐 없냐의 차이만이 존재하기 때문에 Encoder를 기준으로만 설명을 드리도록 하겠습니다. Data Class에서 원하는 configuration에 대해서 정의하는 것은 다른 클래스들과 동일합니다. 여기에서는 Layernorm이 Flax.linen의 라이브러리에 있으며 위와 같이 사용한다는 것만 보고 넘어가시면 나머지 부분은 너무 쉽게 이해하실 수 있을 것이라고 생각합니다.

 

 이제 위의 클래스들을 모두 합쳐서 Transformer 전체를 구성해보겠습니다.

class Transformer(nn.Module):
  embed_dim : int = 512
  output_dim : int = 512
  src_vocab_size : int = 10000
  trg_vocab_size : int = 10000
  max_length : int = 512 
  n_enc_layers : int = 6
  n_dec_layers : int = 6
  d_k : int = 64
  d_v : int = 64
  d_ff : int = 2048

  def setup(self):
    self.word_emb_src = nn.Embed(self.src_vocab_size, self.embed_dim)
    self.word_emb_trg = nn.Embed(self.trg_vocab_size, self.embed_dim)
    self.pos_emb = nn.Embed(self.max_length, self.embed_dim)
    self.encoder_layers = [EncoderBlock(name=f"EncoderBlock_{i}") for i in range(self.n_enc_layers)]
    self.decoder_layers = [DecoderBlock(name=f"DecoderBlock_{i}") for i in range(self.n_dec_layers)]
    self.fc_out = nn.Dense(self.output_dim)

  def __call__(self, src_input, trg_input, mask):
    enc_out = self.encoder(src_input)
    dec_out = self.decoder(trg_input, enc_out, mask)
    return dec_out
  
  def encoder(self, input):
    pos = jnp.arange(input.shape[-1]).reshape(1,-1).repeat(input.shape[0], axis=0)
    pos_enc = self.pos_emb(pos)
    word_enc = self.word_emb_src(input)

    x = word_enc + pos_enc
    for i in range(len(self.encoder_layers)):
      x = self.encoder_layers[i](x)
    return x
  
  def decoder(self, input, encoding, mask):
    pos = jnp.arange(input.shape[-1]).reshape(1,-1).repeat(input.shape[0], axis=0)
    pos_enc = self.pos_emb(pos)
    word_enc = self.word_emb_trg(input)

    x = word_enc + pos_enc
    for i in range(len(self.decoder_layers)):
      x = self.decoder_layers[i](x, encoding, mask)
    x = self.fc_out(x)
    return x

  여기에서는 위에서 못보았던 것들이 존재합니다. setup함수에서 pytorch 처럼 우리가 정의한 nn.Module의 함수를 layer처럼 가져와서 변수에 담을 수 있다는 사실을 한가지 배울 수 있습니다.

 

  그리고, __call__함수에서 모든 data flow를 작성해야 하는 것이 아니라 이 역시도 pytorch 처럼 다른 함수를 class 에서 정의해주고 가져와서 사용하는 것이 가능하다는 것입니다. 

 

 그리고, 이 부분에서는 layer를 __call__ 에서 정의해주지 않았기 떄문에 @nn.compact decorator를 사용하지 않았다는 것만 주의하고 다시 살펴보시면 쉽게 이해하실 수 있을 것이라고 생각됩니다.  

 

 여기까지가 vanila transformer를 JAX 기반 ML 라이브러리인 Flax를 통해서 구현한 부분입니다. 

 

  생각보다 pytorch를 구성하는 것과 큰 차이가 없이 구성할 수 있다는 것에서 저는 처음 공부할 때 놀랐었습니다. 오히려 익숙해지면 편할 것도 같구요. 하지만 다소 training loop등을 구성하는 부분에서 저에게는 불편함이 있었는데 다음 포스팅에서 다루도록 하겠습니다. 

 

References


[1] https://github.com/Bhavnicksm/vanilla-transformer-jax/

 

GitHub - Bhavnicksm/vanilla-transformer-jax: JAX/Flax implimentation of 'Attention Is All You Need' by Vaswani et al. (https://a

JAX/Flax implimentation of 'Attention Is All You Need' by Vaswani et al. (https://arxiv.org/abs/1706.03762) - GitHub - Bhavnicksm/vanilla-transformer-jax: JAX/Flax implimentation of '...

github.com