Pruna and TritonServer
Pruna + Triton Workflow Overview
Here’s a high-level view of the process:
Prepare the Model: Use Pruna to optimize your machine learning model.
Integrate with Triton: Deploy the optimized model in Triton using its flexible Python backend.
Deploy and Test: Run Triton with your optimized model and validate performance with a client script.
Example: Deploying Stable Diffusion with Pruna and Triton
Let’s walk through deploying an optimized Stable Diffusion model with Pruna and Triton, based on the following example GitHub repository.
Step 1: Preparing the Environment
Before getting started, ensure you have the following installed:
Docker: needed to run the Triton Inference Server.
Python 3.8 or higher: needed to work with Pruna and the Triton client library.
Triton Client Library: install it with pip install "tritonclient[grpc]" (the quotes keep shells such as zsh from expanding the brackets).
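Before moving on, you can confirm the prerequisites with a quick script. This is a minimal sketch that only checks the Python version, whether a docker executable is on the PATH, and whether the tritonclient package can be found. It does not verify GPU drivers or the NVIDIA Container Toolkit, and check_environment is a hypothetical helper name.

```python
import importlib.util
import shutil
import sys


def check_environment(min_python=(3, 8)):
    """Return a list of human-readable problems; an empty list means the basics are in place."""
    problems = []
    # Compare (major, minor) against the required minimum
    if sys.version_info[:2] < tuple(min_python):
        found = ".".join(map(str, sys.version_info[:3]))
        problems.append(f"Python {found} found, {min_python[0]}.{min_python[1]}+ required")
    # Docker is needed on the host to run the Triton container
    if shutil.which("docker") is None:
        problems.append("docker executable not found on PATH")
    # Only checks that the top-level package exists, not its gRPC extras
    if importlib.util.find_spec("tritonclient") is None:
        problems.append('tritonclient not installed (pip install "tritonclient[grpc]")')
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks OK"]:
        print(line)
```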
Step 2: Build the Triton + Pruna Docker Image
Create a Dockerfile to build an image that includes Triton Server, Pruna, and all required dependencies. You can find a full example of a Dockerfile here. The Dockerfile performs the following steps:
Start with NVIDIA's Triton Server base image.
Install Pruna with GPU support (pruna[gpu]).
Add any Python libraries your model needs (e.g., PyTorch, diffusers, transformers…).
Build the image:
docker build -t tritonserver_pruna .
Note: You can check out the full Dockerfile example here.
Step 3: Configure the Model for Triton
Triton serves models from a model repository: a directory with one subdirectory per model, each holding a config.pbtxt file and one or more numbered version subdirectories. In this tutorial, we serve a single Stable Diffusion model with the layout below. To serve more models, add one such subdirectory per model.
model_repository/
└── stable_diffusion/
    ├── config.pbtxt
    └── 1/
        └── model.py
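If you prefer to create this layout from Python rather than by hand, the sketch below builds the directories and empty placeholder files. The function name scaffold_model_repository and its defaults are hypothetical; the version directory "1" matches the tree above.

```python
from pathlib import Path


def scaffold_model_repository(root, model_name="stable_diffusion", version="1"):
    """Create root/<model_name>/config.pbtxt and root/<model_name>/<version>/model.py."""
    model_dir = Path(root) / model_name
    version_dir = model_dir / version
    # parents=True creates root and model_dir as needed; exist_ok makes reruns safe
    version_dir.mkdir(parents=True, exist_ok=True)
    config_path = model_dir / "config.pbtxt"
    model_path = version_dir / "model.py"
    for path in (config_path, model_path):
        path.touch(exist_ok=True)  # empty placeholder, filled in during the next steps
    return config_path, model_path
```

For example, scaffold_model_repository("model_repository") reproduces the tree shown above in the current directory.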
Model Configuration (config.pbtxt)
The config.pbtxt file defines the model's input and output interface and its GPU settings. We provide a full example of the config.pbtxt here. For Stable Diffusion, the configuration might look like this:
Inputs: a single string (the text prompt).
Outputs: a 512x512 image with 3 color channels.
Batch size: up to 4 requests processed together in one batch.
If your model uses different input and output types, adapt the config.pbtxt by following the Triton documentation here.
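As an illustration, the sketch below renders a config.pbtxt matching the interface described above. It assumes the tensor names INPUT_TEXT and OUTPUT used by the client script in Step 5, a uint8 image output, one GPU instance, and dynamic batching. All of these are assumptions, not necessarily what the repository's config uses. Because max_batch_size is set, the dims exclude the batch dimension.

```python
def make_config_pbtxt(max_batch_size=4, height=512, width=512):
    """Render a Triton config.pbtxt for a text-to-image model on the Python backend."""
    # Doubled braces are literal braces in the f-string output
    return f"""name: "stable_diffusion"
backend: "python"
max_batch_size: {max_batch_size}
input [
  {{
    name: "INPUT_TEXT"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }}
]
output [
  {{
    name: "OUTPUT"
    data_type: TYPE_UINT8
    dims: [ {height}, {width}, 3 ]
  }}
]
instance_group [
  {{
    count: 1
    kind: KIND_GPU
  }}
]
dynamic_batching {{ }}
"""
```

Writing the result with Path("model_repository/stable_diffusion/config.pbtxt").write_text(make_config_pbtxt()) places it where Triton expects it. The name field must match the model's directory name.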
Python Backend Implementation (model.py)
The model.py file handles model loading and inference. With Pruna, you can add optimizations such as step caching, which reuses computation across denoising steps to reduce inference time. The key steps are:
Load the Stable Diffusion pipeline.
Apply Pruna’s step caching compiler with your token.
Define the Triton inference workflow.
Refer to the config.pbtxt and model.py in the repository for a complete example.
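The three steps above can be sketched as follows. Triton's Python backend loads a class named TritonPythonModel and calls its initialize, execute, and (optional) finalize methods. The Pruna part is the least certain piece: the import paths, the "compilers" key, and the token argument follow older Pruna releases and may differ in yours. The model ID and the PRUNA_TOKEN environment variable are also assumptions. Treat this as a sketch, not the repository's implementation.

```python
import os
from typing import List

import numpy as np


def decode_prompts(raw: np.ndarray) -> List[str]:
    """Flatten a Triton BYTES tensor of shape [batch, 1] into Python strings."""
    return [p.decode("utf-8") if isinstance(p, bytes) else str(p) for p in raw.reshape(-1)]


class TritonPythonModel:
    """Sketch of a Triton Python-backend model serving a Pruna-optimized pipeline."""

    def initialize(self, args):
        # Heavy imports are deferred so this module can be imported outside Triton.
        import torch
        from diffusers import StableDiffusionPipeline
        # ASSUMPTION: import paths and the "compilers" key follow older Pruna
        # releases; check the Pruna docs for the version you install.
        from pruna.algorithms.SmashConfig import SmashConfig
        from pruna.smash import smash

        # ASSUMPTION: model ID; use the checkpoint your repository targets.
        pipe = StableDiffusionPipeline.from_pretrained(
            "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
        ).to("cuda")

        smash_config = SmashConfig()
        smash_config["compilers"] = ["step_caching"]
        # ASSUMPTION: the token comes from a hypothetical PRUNA_TOKEN env variable.
        self.pipe = smash(model=pipe, token=os.environ["PRUNA_TOKEN"], smash_config=smash_config)

    def execute(self, requests):
        import triton_python_backend_utils as pb_utils

        responses = []
        for request in requests:
            raw = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
            prompts = decode_prompts(raw)
            # output_type="np" yields float images in [0, 1] with shape [N, H, W, 3]
            images = self.pipe(prompts, output_type="np").images
            images = (np.clip(images, 0.0, 1.0) * 255).round().astype(np.uint8)
            output = pb_utils.Tensor("OUTPUT", images)
            responses.append(pb_utils.InferenceResponse(output_tensors=[output]))
        return responses

    def finalize(self):
        # Drop the pipeline reference so GPU memory can be released on unload.
        self.pipe = None
```

If you read the token from an environment variable like this, pass it to the container with docker run -e PRUNA_TOKEN=<your token> when starting the server.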
Step 4: Run the Triton Server
Once the model repository is ready, run Triton with your Docker container:
docker run --rm --gpus=all -p 8000:8000 -p 8001:8001 -p 8002:8002 \
  -v "/absolute/path/to/your/model_repository:/models" \
  tritonserver_pruna tritonserver --model-repository=/models
Here is what each part of the command does:
--rm: Removes the container once it stops.
--gpus=all: Enables GPU acceleration, using all available GPUs.
-p 8000:8000 -p 8001:8001 -p 8002:8002: Publishes Triton's default ports: 8000 for HTTP requests, 8001 for gRPC requests, and 8002 for Prometheus metrics.
-v "/absolute/path/to/your/model_repository:/models": Mounts the model repository directory at /models inside the container. Replace /absolute/path/to/your/model_repository with the absolute path to your model repository.
tritonserver_pruna: The name of the Docker image built in Step 2.
tritonserver --model-repository=/models: Runs Triton and specifies the directory where models are stored.
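Loading and optimizing the pipeline inside the container can take a while, so the server may not accept requests as soon as it starts. The sketch below polls Triton until both the server and the model report ready. It works with any client object exposing is_server_ready() and is_model_ready(name), such as tritonclient.grpc.InferenceServerClient. The function name and the timeout defaults are assumptions.

```python
import time


def wait_until_ready(client, model_name, timeout_s=300.0, poll_s=2.0):
    """Poll until the server and `model_name` are ready; return False on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            if client.is_server_ready() and client.is_model_ready(model_name):
                return True
        except Exception:
            # The connection is typically refused while the container is starting.
            pass
        time.sleep(poll_s)
    return False
```

For example, call wait_until_ready(InferenceServerClient(url="localhost:8001"), "stable_diffusion") before sending the first request.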
Step 5: Run the Client Script
With the server running, use the tritonclient Python library to send a request. The following example script sends a text prompt to the stable_diffusion model from the directory structure above and retrieves the generated image.
import numpy as np
from tritonclient.grpc import InferenceServerClient, InferInput

# Connect to the Triton server over gRPC (port 8001)
client = InferenceServerClient(url="localhost:8001")

# Prepare the input: one prompt as a [1, 1] tensor of Triton type BYTES
input_text = np.array(["a serene mountain view"], dtype=object).reshape(-1, 1)
input_tensor = InferInput("INPUT_TEXT", list(input_text.shape), "BYTES")
input_tensor.set_data_from_numpy(input_text)

# Perform inference and read the output tensor as a NumPy array
response = client.infer(model_name="stable_diffusion", inputs=[input_tensor])
output_data = response.as_numpy("OUTPUT")
print(f"Generated image array with shape {output_data.shape}")
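To inspect the result, you can save the returned array as PNG files. This sketch assumes OUTPUT has shape [batch, H, W, 3] and holds either uint8 values or floats in [0, 1]. It relies on Pillow, an extra dependency that tritonclient does not install, and the helper names are hypothetical.

```python
import numpy as np


def to_uint8(image: np.ndarray) -> np.ndarray:
    """Convert one HxWx3 image to uint8, accepting uint8 or floats in [0, 1]."""
    if image.dtype == np.uint8:
        return image
    return (np.clip(image, 0.0, 1.0) * 255).round().astype(np.uint8)


def save_images(output_data: np.ndarray, prefix: str = "generated") -> list:
    """Write each image in a [batch, H, W, 3] array to a PNG file and return the paths."""
    from PIL import Image  # Pillow: pip install pillow

    paths = []
    for i, image in enumerate(output_data):
        path = f"{prefix}_{i}.png"
        Image.fromarray(to_uint8(image)).save(path)
        paths.append(path)
    return paths
```

Calling save_images(output_data) after the script above writes generated_0.png to the working directory.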
🎉That's it. You now have a working example of using Pruna AI with Triton server!🎉