How to submit many concurrent requests to Ray Serve

Ray Serve can dynamically batch incoming requests and process them in chunks. However, if you send your queries one at a time with requests.post(), each call blocks until its response arrives, so the server only ever sees a single request in flight and batching never kicks in.
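
For contrast, here is the blocking anti-pattern as a minimal sketch (the URL and payloads are placeholders matching the full example below). Each post() waits for its response before the next request goes out, so the server never has more than one request to group into a batch:

import requests

inputs = [{"value": 1.0}] * 100

# Anti-pattern: sequential, blocking calls. Each request completes
# before the next one starts, so batching never triggers.
for item in inputs:
    prediction = requests.post("http://127.0.0.1:8000/", json=item).json()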

Instead, fire many requests concurrently with an asynchronous HTTP client and let Ray Serve buffer and batch them on the server side. You can accomplish this with aiohttp:

import asyncio
import time

import aiohttp
import numpy as np
import requests
from ray import serve
from ray.serve.handle import DeploymentHandle
from starlette.requests import Request

# Toy "model": returns one random score per input item.
def model(x):
    return np.random.rand(len(x))


@serve.deployment
class BatchedModel:
    def __init__(self):
        self.model = model

    @serve.batch(max_batch_size=5, batch_wait_timeout_s=0.1)
    async def process_batch(self, input_data: list[dict]) -> list[float]:
        print(f"Processing batch of size: {len(input_data)}")

        # Score the whole batch at once, then convert the numpy array to a
        # plain list so Serve can return one JSON-serializable result per
        # request in the batch.
        results = self.model(input_data)
        return results.tolist()

    async def __call__(self, request: Request):
        input_data = await request.json()
        # Route the request to the batch handler
        return await self.process_batch(input_data)


def main():
    # Bind and deploy the app; avoid the name `model`, which would shadow
    # the module-level function. serve.run returns a handle for direct calls.
    app = BatchedModel.bind()
    handle: DeploymentHandle = serve.run(app, name="batched-model")

    # Simplified sample input
    sample_input = {"value": 1.0}

    url = "http://127.0.0.1:8000/"

    # --- Test with a single request ---
    print("\n--- Sending single request ---")
    start_time = time.time()
    prediction = requests.post(url, json=sample_input).json()
    end_time = time.time()
    print(f"Prediction: {prediction}")
    print(f"Time taken: {end_time - start_time:.4f}s")


    # --- Simulate many concurrent requests ---
    print("\n--- Sending 100 concurrent requests ---")

    sample_input_list = [sample_input] * 100

    async def fetch(session, url, data):
        async with session.post(url, json=data) as response:
            return await response.json()

    async def fetch_all(payloads: list):
        # Fan out all requests concurrently over a shared session; gather
        # preserves the input order in the returned responses.
        async with aiohttp.ClientSession() as session:
            tasks = [fetch(session, url, payload) for payload in payloads]
            return await asyncio.gather(*tasks)

    start_time_main = time.time()
    responses = asyncio.run(fetch_all(sample_input_list))
    end_time_main = time.time()

    # The toy model returns random scores, so individual values will differ
    # from run to run.
    print(f"First response: {responses[0]}")
    print(f"Total time (including client-side async setup overhead): {end_time_main - start_time_main:.4f}s")


if __name__ == "__main__":
    main()
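
Because serve.run returns a DeploymentHandle, a client running in the same Python process can also skip HTTP entirely and call the batched method through the handle. Here is a minimal sketch (query_via_handle is a hypothetical helper, not part of Ray's API); each .remote() call is dispatched immediately without blocking, so @serve.batch can still group the calls on the replica:

def query_via_handle(handle: DeploymentHandle, n: int = 100) -> list[float]:
    # All n calls are in flight before we start collecting results,
    # which gives @serve.batch the chance to group them.
    responses = [handle.process_batch.remote({"value": 1.0}) for _ in range(n)]
    # DeploymentResponse.result() blocks until that call's result is ready.
    return [r.result() for r in responses]

For example, you could call query_via_handle(handle) at the end of main() with the handle that serve.run returned.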

Copyright Ricardo Decal. richarddecal.com