본문 바로가기
MLOps

[MLOps, Acceleration, Jax] Google Jax 란?

by ML_MJSHIN 2021. 12. 13.

  당분간은 Google JAX에 대해서 소개하고 이에 대해 활용하는 것과 관련된 간단한 포스팅을 진행하려고 합니다. 최근 공부해서 회사의 ML 엔진에 적용할 기회가 생겨서 작업을 하고 있었는데, 역시 익숙하지 않은 상태로 적용하고 하다보니 스스로 다시 정리할 기회도 만들고 혹시나 저처럼 처음 보시는 분들에게도 도움이 되길 바라며 포스팅을 하려고 합니다.

 

  아마 JAX에 관한 포스팅은 소개부터 시작해서 Google DeepMind의 코드를 리뷰해보는 것들까지 진행하면서 여러개의 포스팅이 될 것 같습니다. 관심있으신 분은 계속 살펴봐주세요 ㅎㅎ

 

 Google JAX


  JAX는 machine learning research를 돕기위해서 개발된 Pytorch와 TensorFlow처럼 automatic differentiation (AutoGrad라고 불리는 기능)을 가지고 있는 CPU, GPU 그리고 TPU에서 동작하는 Numpy framework 입니다.

 

  가장 큰 장점은 Numpy의 연산들을 GPU에서 가능할 수 있게 해주어 기존의 Numpy의 성능을 뛰어넘어 훨씬 빠른 연산 속도를 보여줍니다. 그리고, JIT (just in time) 컴파일 기법고 XLA 컴파일러 (Accelerated Linear Algebra)를 사용하여 컴파일 할 수 있어 런타임에 사용자가 생성한 TensorFlow 그래프를 분석하고 실제 런타임 차원과 유형에 맞게 최적화 및 여러 연산을 함께 합성하고 이에 대한 효율적인 네이티브 기계어 코드를 내보낸다고 합니다 [2].  

 

  하지만 JAX를 활용하기 위해서는 몇가지 주의해야 할 사항들이 공식 docs에 정리가 되어 있습니다. 

 

1. JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.

 

-> JAX를 작성한 python function 에 적용하고자 하는 경우 pure python function 형태여인 경우만 정상적인 작동이 보장된다는 의미입니다. 

 

  위의 글만으로는 이해가 어렵기 때문에 예시를 보면서 설명하도록 하겠습니다. 아래에 python function이 있으며 이 function은 외부의 global variable을 참조하여 연산을 수행하고 있습니다. 이런 경우는 pure python function이 아닌 경우입니다. pure python function이란 function이 입력이 같으면 output의 같음이 항상 보장되는 함수인데 반해 아래의 함수는 g 라는 외부 변수에 의해서 output이 변경될 수 있기 때문입니다.

 

  만약 아래와 같은 경우 문제가 생기는 경우는 다음과 같습니다. 저희는 아래와 같은 코드를 구성하였을 때 다음과 같은 결과를 기대하였습니다.

 - First call : '4', Second call  : '15', Third call : ['14']

 

  하지만 실제로 JAX를 사용해서 아래의 결과를 시행해보면 다음과 같은 결과가 반환됩니다.

 - First call : '4', Second call : '5', Third call : ['14'] 

 

  이는 JAX는 기본적으로 함수가 pure function일 것을 가정하고 함수에서 사용된 변수들을 cache 처리하고 연산을 수행합니다. 이때문에 Second call에서는 g 값을 cache 로 저장해두어 5 + 0의 연산을 수행하게 됩니다. 그러나 마지막의 경우 함수의 입력 인자의 shape가 float 에서 array 로 변환되면서 JAX가 함수에 대한 re-compile의 필요성을 인식하게되고 다시 compile을 수행하게 되면서 변환된 g의 값을 가져와서 cache를 수행하게되므로 14 라는 결과값이 나오게 됩니다. 

g = 0.
def impure_uses_globals(x):
  return x + g

print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

print ("Second call: ", jit(impure_uses_globals)(5.))

print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

 

2. It is not recommended to use iterators in any JAX function you want to jit or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results.

 

-> JAX 함수에서는 python object인 iterator를 사용하면 기대하지 않은 결과가 나올 수 있다는 것을 설명하고 있습니다.

  이 부분도 역시 예시와 함께 살펴보도록 하겠습니다. 0~9 까지의 integer로 구성된 array를 만들고 더하는 JAX function을 사용할때의 결과를 비교하는 예시입니다. 첫번째 경우에는 jax의 range 함수를 사용해서 array를 만들엇기 때문에 python iterator가 아닌 array 가 존재하여 기대하는 45라는 결과를 출력합니다. 반면, python iterator로 똑같이 array를 선언한다면 0 이라는 기대하지 않은 결과를 출력하게 됩니다. 

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

 

3. JAX device array에 대한 In-Place Updates가 불가능 합니다. 아래의 예시에서는 set 만을 설명하고 있지만 'at' 관련 operation은 다양하게 존재합니다 (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at).

jax_array = jnp.zeros((3,3), dtype=jnp.float32) # JAX Device array 선언 

try:
  jax_array[1, :] = 1.0 # In-Place Update는 오류를 발생시킵니다.
  # jax_array.at[1, :].set(1.0) 을 사용해야함
except Exception as e:
  print("Exception {}".format(e))

 

4. python list or tuple은 JAX function의 입력으로 넣을 경우 error를 반환합니다. 이 부분은 JAX에서 Python list, tuple의 element들을 각각 다른 JAX 변수로써 처리하기 때문에 performance 저하를 불러일으키기 때문에 방지된 사용법 입니다. 그러므로, 효과적인 JAX 사용을 위해서 list 와 tuple은 JAX array혹은 numpy로 변환하여 전달해주어야 합니다. 

try:
  jnp.sum([1, 2, 3])
  # jnp.sum(jnp.array([1,2,3]))
except TypeError as e:
  print(f"TypeError: {e}")

 

  기본적인 4가지 주의사항 이외에도 사용법에 대한 주의 사항이 존재하지만 중요한 포인트들은 위에 정리한 것들인 것 같습니다. (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers)

 

 

Pytorch to JAX


  저는 회사에서 Pytorch 로 ML 엔진을 구축해서 작업을 하고 있기 때문에 실질적으로 Pytorch 와 JAX의 호환성과 JAX로 변환하는 과정에 대해서 관심이 많이 있었습니다. 그리고, [3] 의 사이트에서 이에 대해서 상세히 설명을 해주고 있었기에 시작하는데 어려움이 없었기에 해당 포스트의 내용을 리뷰하면서 이번 JAX 소개 포스팅을 마무리하려고 합니다. 다음 포스팅 부터는 실제 예제 코드와 함께 정리를 해볼 생각입니다.

 

 우선 Pytorch 에서 JAX 로의 변환 작업은 기본적으로 ML 을 위한 연산 구성과 backpropagation 등의 방법에서 생각보다 큰 차이가 존재합니다.  

https://sjmielke.com/jax-purify.htm

 Pytorch를 통해서 ML 모델의 연산을 구성하는 경우를 생각해보겠습니다. 저희는 Forward pass를 따라서 수식을 코드로 변환하고 ML 모델 graph를 만든 뒤 필요할 때 backward 함수를 호출합니다. Pytorch는 이렇게 backward 함수가 호출되었을 때 autodiff 가 graph를 따라 gradient를 전파하고 node들을 업데이트 합니다.

 

 반면, JAX는 Forward pass를 따라서 연산을 코드로 구현하는 것으로 자동으로 graph가 구성되는 형태가 아닙니다. 대신 필요한 연산별로 python function을 만들어야 합니다. 이때, 정의한 python function을 jax.grad 함수에 전달하는 것으로 JAX는 gradient function을 반환해주게되며 gradient_function에 forward pass와 동일한 입력값을 넣어주면 gradient 값을 반환해주게 됩니다. 

 

 문제는 이런 방식이 기존에 Pytorch로 구성하는 방식에서 변환할 떄 굉장한 어려움을 주게 됩니다. 단순하게 생각해보면 Pytorch에서 하나의 ML model을 구성할 때 했던 forward 함수 내부의 연속적인 작업들을 어떻게 함수 단위로 모듈화해서 분리시키는 것이 가능할까 라는 생각부터 들게 됩니다. 

 

 이 부분을 정확히 이해하기 위해서는 위에서 설명드린 Pure functions를 어떻게 JAX에서는 다루는지 자세히 살펴볼 필요가 있습니다.

 

Pure functions


  Pure function에 대한 검색을 해보면 다음과 같은 정의를 찾을 수 있습니다. 오로지 input argument들에 의해서만 output argument가 결정되는 함수를 pure function 이라고 합니다. 

더보기

a pure function is like a function or formula in math. It defines how an output value is obtained from some input values. What's important is that it has no “side effects”: no part of the function should access or even mutate any global state.

  

 쉬운 이해를 위해서 예제를 가져왔습니다. 같은 W*X 를 수행하는 linear regression 모델을 구현해도 pytorch는 아래와 같이 구현을 할 것입니다.

class Temp(nn.Module):
	def __init__(self, dim):
    	self.w = torch.nn.Parameter(torch.rand(1, dim))
        
    def forward(self, x):
    	return self.w * x

 반면 JAX로의 구현은 다음과 같습니다. 이때 JAX는 grad 를 통해서 함수의 gradient 함수를 얻는 경우 입력 함수 (ex., f)의 첫번째 parameter의 측면에서의 gradient를 구합니다. (ex., f(w, x) -> w).

def f(w, x):
    return w * x

# print(f(13., 42.)) # 546.0

import jax
import jax.numpy as jnp

# Gradient: with respect to weights! JAX uses the first argument by default.
df_dw = jax.grad(f)

def manual_df_dw(w, x):
    return x

assert df_dw(13., 42.) == manual_df_dw(13., 42.)

print(df_dw(13., 42.)) # 42

 

 차이가 살짝 보이시나요? ㅎㅎ 아직 이 예제로는 충분한 이해를 할 수 없기 때문에 실제 코드들을 보면서 공부를 하는 포스팅을 이어서 작성하려고 합니다. 우선 이렇게 간단한 소개로 이번 포스팅은 마치고 다음 포스팅에서 구체적인 예제와 함께 살펴보도록 하겠습니다. 

 

 

References


[1] https://jjeamin.github.io/posts/jax/ 

 

Google JAX 라이브러리란?

JAX

jjeamin.github.io

 

[2] https://developers-kr.googleblog.com/2017/03/xla-tensorflow-compiled.html

 

TensorFlow용 컴파일러인 XLA를 소개합니다.

<블로그 원문은 여기 에서 확인하실 수 있으며, 블로그 번역 리뷰는 Justin Hong(Google)님이 참여해 주셨습니다.> 게시자: Google XLA 팀(TensorFlow 팀 공동 게시) TensorFlow의 디자인 목표와 ...

developers-kr.googleblog.com

[3] https://sjmielke.com/jax-purify.htm

 

From PyTorch to JAX: towards neural net frameworks that purify stateful code — Sabrina J. Mielke

From PyTorch to JAX: towards neural net frameworks that purify stateful code 2020-03-09 Update 2021-07-01: I gave a talk at the Flax/JAX community week largely based on this blogpost---but made a bit more concise and punchy and including an example of flax

sjmielke.com