본문 바로가기

MLOps

NVIDIA Triton 한 눈에 알아보기

NVIDIA Triton 서버는 오픈소스 소프트웨어로 제공되는 머신러닝 모델 inference 서버다. 학습된 모델 파일을 model repository에 저장하면 개발자가 별도의 코드를 작성할 필요 없이 해당 모델을 추론하는 API를 만들 수 있다.

Triton 백엔드란 모델을 실행하는 C 구현체를 일컫는 말로 TensorFlow, PyTorch 등 주요 머신러닝 프레임워크의 wrapper가 될 수도 있고 커스텀 C/C++ 코드가 될 수도 있다. TensorRT, Tensorflow, PyTorch, ONNX Runtime, Python 등의 백엔드가 디폴트로 제공된다. Triton의 아키텍처를 한 번 살펴보자.

 

Triton은 model repository라는 저장소에서 서빙할 모델 파일들을 불러온다. 클라이언트는 HTTP/REST 및 GRPC 프로토콜 혹은 Triton에 내장된 C API를 통해 모델 추론 요청을 보낼 수 있다. Triton은 해당 모델의 정책에 따라 요청을 batching 및 scheduling한 후 추론 연산을 수행한다. 추론 결과는 다시 클라이언트에게 반환된다.

Triton으로 모델을 서빙하기 위하여 개발자가 해야할 일은 다음과 같다. 

  1. 서버를 세팅한다. 필자의 경우 AWS ec2 인스턴스에 세팅하였다. 다음 설정 내용은 GPU 타입의 인스턴스인 g4dn.xlarge 기준이다.
    • 인스턴스 타입은 g4dn.xlarge 인스턴스로, OS 종류는 Ubuntu 20.04, 디스크 볼륨은 Triton의 최소 요구사항인 100GB 이상으로 설정한다.
    • GPU 타입의 인스턴스에서 머신러닝 작업을 하기 위해 필요한 AMI(Amazon Machine Image)를 설치한다. 필자는 ami-00a79cef350bf9ca5가 설치된 ec2 인스턴스에서 테스트하였다.
  2. Model repository를 생성하고 서빙할 모델 파일들을 여기에 저장한다.
  3. Nvidia cloud에서 제공하는 Triton 서버 이미지 nvcr.io/nvidia/tritonserver:xx.xx-py3를 서버에서 docker run한다.
    • xx.xx는 릴리즈 버전으로 필자는 22.12를 선택했다.
    • Run 커맨드 : docker run -it --rm --gpus=1 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 (로컬 model repository를 사용할 경우 : -v <로컬 경로>:/opt/tritonserver/model_repository) -p 8000:8000 -p 8001:8001 -p 8002:8002 nvcr.io/nvidia/tritonserver:xx.xx-py3
  4. Triton 서버를 launch한다.
    • Launch 커맨드 : tritonserver <옵션1> <옵션2> ...
  5. Triton 서버에 추론 요청을 보낸다.

 

Model Repository

Model repository는 서빙할 모델 파일들을 저장하는 곳을 말한다. Triton 서버의 local path일 수도 있고 GCP, AWS S3, Azure 등의 클라우드 저장소도 가능하다. Triton 서버를 launch할 때 --model-repository 옵션으로 repository로 사용할 디렉토리를 지정하게 된다. --model-repository 옵션을 여러 번 사용하여 여러 개의 model repository를 지정할 수 있다.

tritonserver --model-repository=<local-repository-path> --model-repository=s3://<s3-repository-path> ...

 

Model repository는 반드시 아래와 같은 layout을 갖춰야한다.

<model-repository-path>/
    <model-1-name>/
        [config.pbtxt]
        [<output-labels-file> ...]
        <version>/
            <model-definition-file>
        <version>/
            <model-definition-file>
        ...
    <model-2-name>/
        [config.pbtxt]
        [<output-labels-file> ...]
        <version>/
            <model-definition-file>
        <version>/
            <model-definition-file>
        ...
    ...

 

Model repository의 각 하위 디렉토리는 특정 모델에 대한 정보를 담고 있다. 디렉토리 이름이 곧 모델명이 된다. config.pbtxt는 모델 configuration이 담긴 파일로 밑에서 더 자세히 다룰 것이다. 각 모델 폴더 안에는 모델 version number를 이름으로 하는 폴더가 반드시 하나 이상 존재해야 한다. 숫자가 아니거나 0으로 시작하는 폴더는 건너뛰도록 되어있다. 각 버전 폴더마다 해당 버전에 해당하는 모델 파일이 들어있다. 이 모델 파일이 학습된 모델 아키텍처와 파라미터들이 저장된 결과물이다. 프레임워크에 따라 모델 파일 형식이 다르다.

  • TensorRT : model.plan
  • ONNX : model.onnx
  • Tensorflow : model.graphdef OR model.savedmodel 이름의 폴더
  • TorchScript : model.pt
  • Python : model.py

이해를 돕기 위해 ONNX 모델인 "model1", TensorFlow 모델인 "model2"와 Python 모델인 "model3"가 들어있는 예시 model repository의 layout을 그려봤다.

<model-repository-path>/
    model1/
        config.pbtxt
        1/
            model.onnx
        2/
            model.onnx
    model2/
        config.pbtxt
        1/
            model.savedmodel/
            	saved_model.pb
                variables/
                	variables.index
                        variables.data-00000-of-00001
    model3/
    	config.pbtxt
        1/
        	model.py
        2/
        	model.py
        3/
        	model.py

 

Model Management

Triton 서버에 모델을 load/unload하기 위한 model management API가 제공된다. Model repository에서 Triton 서버로 모델을 load하는 정책은 총 세가지가 있다 : NONE, EXPLICIT, POLL.

NONE

Triton 서버를 launch할 때 model repository에 있는 모든 모델들을 load한다. 서버가 가동 중일 때는 model repository에 변경 사항이 생기거나 model management API로 요청을 보내도 무시된다.

tritonserver --model-repository=<model-repository-path> --model-control-mode=none

 

EXPLICIT

Triton 서버가 launch할 때 --load-model 옵션으로 명시된 모델들만 load한다. Model repository에 있는 모든 모델들을 load하고 싶으면 --load-model=*로 하면 된다. --load-model 옵션이 없으면 아무 모델도 load되지 않는다.

tritonserver --model-repository=<model-repository-path> --model-control-mode=explicit --load-model=model1 --load-model=model2

서버를 띄운 후에는 model management API를 이용하여 모델을 load/unload할 수 있다.

 

POLL

서버를 처음 launch할 때 model repository에 있는 모든 모델들을 load한다. 그 다음 서버는 주기적으로 model repository의 변경 사항을 감지하여 그에 따라 모델을 load/unload한다. 주기는 --repository-poll-secs 옵션으로 설정할 수 있다. Poll mode에서는 model management API를 사용할 수 없다.

tritonserver --model-repository=<model-repository-path> --model-control-mode=poll --repository-poll-secs=60

Poll mode의 경우 model repository의 변경 사항이 즉각 반영되지 않고 주기적으로 반영되기 때문에 프로덕션 환경에서는 권장되지 않는다.

 

Model Configuration

Model repository의 각 모델마다 config.pbtxt라는 파일에 모델 configuration을 작성하게 된다. 필수적으로 작성해야되는 항목은 backend(사용 프레임워크/백엔드), max_batch_size, input 및 output이다.

  • backend : 사용 프레임워크/백엔드 (ex. "tensorrt", "python", "onnxruntime", "tensorflow", "pytorch")
  • platform : 모델 포맷 (ex. "tensorrt_plan", "tensorflow_savedmodel", "tensorflow_graphdef")
  • max_batch_size : 클라이언트 혹은 Triton이 구성할 수 있는 batch size의 최댓값이다. Batching을 원하지 않을 경우 이 값을 0으로 하면 된다.
  • input : 입력 tensor의 이름, data type 및 shape
  • output : 출력 tensor의 이름, data type 및 shape

다음은 간단한 config.pbtxt의 예시다.

backend: "onnxruntime"
max_batch_size: 8
input [
  {
    name: "input0"
    data_type: TYPE_FP32
    dims: [ 224, 224, 3 ]
  }
]
output [
  {
    name: "output0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]

 

max_batch_size가 0보다 큰 값이기 때문에 input과 output 텐서에는 batch dimension이 첫 번째 dimension으로 붙는다. 즉, 모델의 input tensor의 shape은 [-1, 224, 224, 3], output tensor의 shape은 [-1, 1000]가 된다고 보면 된다. 반면 max_batch_size가 0이라면 batching을 안 하겠다는 뜻이기 때문에 모델은 shape [224, 224, 3]의 input, shape [1000]의 output을 기대하게 된다.

그 외에 필수가 아닌 항목들을 몇 개 적어봤다.

  • name : 모델 이름. Model repository에서 해당 모델의 폴더명과 일치해야한다. 굳이 안 적어도 알아서 폴더명을 모델 이름으로 인식한다.
  • version_policy : 명시하지 않을 경우 해당 모델의 마지막 (가장 큰) 버전을 load한다. 
    • { all { }} : 모든 버전들을 load한다.
    • { latest { num_versions n}} : 마지막 n개의 버전을 load한다.
    • { specific { versions [n1n2]}} : 특정 버전들만 load한다.
  • instance_group : 모델이 실행될 GPU/CPU를 지정하고 각 디바이스 당 모델 instance의 개수를 지정한다.
  • cc_model_filenames : 각 모델 파일과 CUDA 버전을 매핑할 수 있다.
  • dynamic_batching : Dynamic batching을 enable하고 관련 파라미터들을 설정할 수 있다.

 

Concurrent Model Execution

Triton 서버에서는 여러 모델을 병렬로 실행할 수 있고 같은 모델의 instance를 여러 개 만들어서 병렬로 실행할 수 있다. 우선 서로 다른 모델에 각각 요청이 동시에 들어왔을 경우 두 모델 모두 병렬로 연산을 수행하게 된다.

그렇다면 하나의 모델에 여러 요청이 들어오면 어떻게 될까?

Default 설정으로 각 GPU마다 각 모델의 instance를 하나씩 배치하도록 되어있다. 만약 GPU가 하나인 상황에서 특정 모델에 여러 요청이 동시에 들어온다면 Triton은 이들을 직렬화하여 한번에 하나씩 순차적으로 처리할 것이다.

 

사용자는 config.pbtxt의 instance_group 항목을 조정하여 각 GPU마다 모델의 instance 수를 변경할 수 있다. 몇 가지 예시를 보자.

instance_group [
  {
    count: 3
    kind: KIND_GPU
  }
]

 

위와 같이 설정하면 Triton은 모든 GPU에 대하여 각 GPU에 해당 모델의 instance를 3개씩 만든다. 각 GPU마다 해당 모델에 들어온 요청을 최대 3개씩 병렬로 처리할 수 있는 것이다.

instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [0]
  },
  {
    count: 2
    kind: KIND_GPU
    gpus: [1, 2]
  }
]

 

위와 같이 GPU마다 설정을 다르게 할 수 있다. 0번 GPU에는 1개의 instance, 1, 2번 GPU에는 instance를 각각 2개씩 만든 case다.

GPU 말고 CPU를 사용하는 것도 가능하다. 아래 예시에서는 CPU에 해당 모델에 대하여 2개의 실행 인스턴스를 생성한다.

instance_group [
  {
    count: 2
    kind: KIND_CPU
  }
]

 

ONNX Runtime, TensorFlow, PyTorch 등 대부분의 백엔드의 경우 멀티 쓰레딩을 통해 병렬 추론을 한다. 그러나 Python의 경우 GIL로 인해 여러 thread가 병렬적으로 실행되는 것이 불가능하다. 그래서 Python 백엔드의 경우 멀티 프로세싱을 통해 병렬 실행을 한다. 각 model instance가 하나의 개별 process고 당연히 서로 다른 모델은 다른 process에 의해 서빙된다. Python 백엔드를 사용할 때는 이점을 고려하여 instance 수를 결정해야한다.

 

Dynamic Batching

Dynamic batching은 한 모델로 들어온 여러 요청을 배치로 묶어서 한번에 처리하는 기능이다. 모델의 input shape이 [224, 224, 3]이고 8개의 요청이 비슷한 시각에 들어왔다면 이를 [8, 224, 224, 3]의 array로 stack하여 모델이 한번에 추론하는 것이다. 이 기능을 적절히 사용하면 throughput을 크게 향상시킬 수 있다.

config.pbtxt에서 max_batch_size 값을 통해 배치의 최대 크기를 지정할 수 있다. max_batch_size 다음에 신경써야할 파라미터는 max_queue_delay_microseconds다. Triton의 scheduler는 배치 크기가 max_batch_size에 도달했거나 대기 시간이 max_queue_delay_microseconds에 도달한 요청이 발생하면 바로 추론에 들어가도록 한다. 이 두 조건이 충족되기 전까지는 다음 요청을 기다린다.

dynamic_batching {
  max_queue_delay_microseconds: 100
}

 

Ensemble Model

Ensemble 모델은 여러 모델의 파이프라인을 말한다. 한 모델의 output tensor가 다른 모델의 input tensor가 되면서 여러 모델이 연결되는 형태다. 예를 들면, 하나의 추론 요청에 대하여 전처리 -> 딥러닝 모델 추론을 해야 한다면 이 두 단계를 하나의 ensemble 모델로 캡슐화할 수 있다. 그러면 Triton은 전처리 단계의 input tensor을 요청으로 받아서 딥러닝 모델 추론 결과를 한방에 반환할 것이다. 이렇게 하면 각 단계가 끝날 때마다 중간 output을 받고 다음 단계를 위해 Triton에 다시 요청을 보내는 번거로움을 없앨 수 있다. 예시 ensemble 모델의 configuration 파일을 한번 살펴보자.

platform: "ensemble"
max_batch_size: 1
input [
  {
    name: "IMAGE"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]
  }
]
output [
  {
    name: "CLASSIFICATION"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "image_preprocess_model"
      model_version: -1
      input_map {
        key: "RAW_IMAGE"
        value: "IMAGE"
      }
      output_map {
        key: "PREPROCESSED_OUTPUT"
        value: "preprocessed_image"
      }
    },
    {
      model_name: "classification_model"
      model_version: -1
      input_map {
        key: "FORMATTED_IMAGE"
        value: "preprocessed_image"
      }
      output_map {
        key: "CLASSIFICATION_OUTPUT"
        value: "CLASSIFICATION"
      }
    }
  ]
}

 

"image_preprocessed_model"의 output_map의 value 값와 "classification_model"의 input_map의 value 값이 일치한다. 전처리를 먼저 수행하고 다음에 분류 모델에 대하여 추론해야한다는 것을 알 수 있다. 또한 ensemble 모델의 최종 input과 output이 각각 전처리 모델의 input, 분류 모델의 output과 이름이 일치하는 것을 확인할 수 있다.

 

Triton 서버에 추론 요청 보내기

Model repository를 구축하고 Triton 서버를 launch하는 것까지 완료했다고 하자. Launch할 때 발생한 로그를 통해 필요한 모델들의 Status가 "READY"인지 확인하여 정상적으로 서빙 가능한지 확인할 수 있다. 더불어 GRPC Service는 8001번, HTTP Service는 8000번, Prometheus Metrics Service는 8002번 포트에 뜬 것을 확인할 수 있다.

이제 health check를 해서 서버가 정상적으로 통신되는지 확인해보자.

curl -v <ec2 인스턴스 ip>:8000/v2/health/ready

 

200 응답이 와야한다. 이번엔 jupyter notebook에서 이름이 "model1"인 모델에 추론 요청을 보내자.

import requests
import json
import numpy as np

url = 'http://<ec2 인스턴스 ip>:8000/v2/models/model1/versions/1/infer'

array = np.zeros([1,224,224,3]).tolist()
payload = {
	"inputs": [
    	{
            "name": "input0",
            "datatype": "FP32",
            "shape": [1,224,224,3],
            "data": array
        }
    ]
}

response = requests.post(url, data=json.dumps(payload))
print(response.text)

 

Input의 name, datatype과 shape은 config.pbtxt를 보면 알 수 있다.

만약에 해당 모델의 버전이 한 개밖에 없으면 url에서 버전은 생략해도 된다.

  • http://<ec2 인스턴스 ip>/v2/models/model1/infer

이렇게 단순히 Python requests로 요청을 보내면 요청 array에 batch dimension을 만들어줘야 되고 응답 array가 전부 flatten되어 나오는 등의 불편함이 있다. 또한 url과 요청 바디를 구성하는 것이 번거로울 수 있다. 그래서 tritonclient 라이브러리를 설치하고 이를 통해서 서버에 요청을 보내는 것이 권장된다.

pip install tritonclient			# 로컬 요청만 가능
pip install tritonclient[http]			# http 클라이언트
pip install tritonclient[grpc]			# grpc 클라이언트
pip install tritonclient[all]			# http, grpc 클라이언트

 

tritonclient로 http, grpc 요청을 보내는 예시 Python 코드를 참고하자.

 

perf_analyzer

Triton은 perf_analyzer라는 매우 편리한 성능 테스트기를 제공한다. 

안정적으로 성능 테스트를 진행하기 위해서는 클라이언트 서버를 따로 구축하는 것이 좋다. 필자는 SageMaker ml.m5.2xlarge 노트북 인스턴스에 구축했다. 노트북 터미널에서 다음과 같이 Nvidia 클라우드에서 제공하는 Triton 클라이언트 이미지를 실행한다.

docker run -it --rm --net=host nvcr.io/nvidia/tritonserver:22.12-py3-sdk

 

그 다음 "model1"에 대한 성능 테스트를 수행한다.

perf_analyzer -m model1 -u <ec2 인스턴스 ip>:8000 --shape input0:224,224,3 --concurrency-range 1:100:5 -f result.csv --input-data zero

 

옵션 설명

  • -m
    • 테스트하고자 하는 모델 이름
  • -u
    • Triton 서버 주소
  •  --concurrency-range
    • 숫자인 경우 : 동시 요청 수
      • ex. --concurrency-range 10
    • 범위인 경우 : 시나리오
      • ex. --concurrency-range 1:100:5
      • 예시의 경우 동시 요청 수를 1부터 시작해서 100이 되기 전까지 5씩 늘려가겠다는 것을 의미한다.
      • 각 concurrency level마다 time window 동안 avg latency가 안정될 때까지 기다린 후 자동으로 다음 단계로 넘어간다.
  • --shape
    • input tensor의 shape
    • input이 여러 개인 경우 이 옵션을 여러 번 사용하면 된다.
    • 형식 : -f <input name>:<input shape>
    • Input tensor의 shape이 fixed인 경우 이 옵션이 필요 없다. config.pbtxt에서 해당 input의 shape에 -1이 없으면 fixed인 것이다.
  • -f
    • 결과 저장 파일 이름
  • --input-data
    • random : random input tensor (default)
    • zero : zeros input tensor
    • json 파일 경로 : custom input

 

결과

각 concurrency level에서의 throughput(RPS, # of Inferences/sec)과 latency(avg, p50, p90, p95, p99)를 출력할 수 있다.

 

참고 자료

https://github.com/triton-inference-server/server 

 

GitHub - triton-inference-server/server: The Triton Inference Server provides an optimized cloud and edge inferencing solution.

The Triton Inference Server provides an optimized cloud and edge inferencing solution. - GitHub - triton-inference-server/server: The Triton Inference Server provides an optimized cloud and edge i...

github.com

https://github.com/triton-inference-server/client 

 

GitHub - triton-inference-server/client: Triton Python, C++ and Java client libraries, and GRPC-generated client examples for go

Triton Python, C++ and Java client libraries, and GRPC-generated client examples for go, java and scala. - GitHub - triton-inference-server/client: Triton Python, C++ and Java client libraries, and...

github.com

 

 

'MLOps' 카테고리의 다른 글

Triton Python Backend 사용하기  (1) 2023.05.14
AWS Serverless 2편  (1) 2023.01.05
AWS Serverless 1편  (0) 2023.01.02
Amazon SageMaker  (0) 2022.04.25
Multi-Armed Bandit with Seldon Core  (0) 2022.02.20