Python SDK Tutorial Part 3

Creating a simple rate limited client

Now that we have implemented the minimum necessary to make individual requests, let's enhance our Python client with a multi-threaded, concurrent implementation. Restb.ai's API allows a single request per second (rps) by default, but higher concurrency (n requests per second) is offered as a premium service. Regardless, it is a good idea to implement a multi-threaded approach even when accessing our API at 1 rps.
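
As a quick sanity check on expected throughput: this tutorial's example queue holds 36 entries, so at an allotted 4 rps the batch needs at least 36 / 4 = 9 seconds of wall time; the measured numbers at the end of this page should land in that neighborhood.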

One scenario that must now be handled is the 429 'Too Many Requests' HTTP response code. If your client receives this response from our API Gateway server (APIGW), it means your client has exceeded the concurrency allotted to your account and must retry the request later.
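
To illustrate the retry idea in isolation before building the full client, here is a minimal sketch built directly on the requests library, independent of the SDK's service helper. The helper name, retry count, and use of GET are illustrative assumptions, not part of the SDK:

import time

import requests


def get_with_retry(url, params, max_retries=5, pause_seconds=1.0):
    # hypothetical helper: retry a GET on HTTP 429, pausing between attempts
    resp = None
    for _ in range(max_retries):
        resp = requests.get(url, params=params)
        if resp.status_code != 429:
            return resp
        time.sleep(pause_seconds)  # back off before retrying
    return resp  # still rate limited after max_retries attempts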

To begin, let's create a new example client restb/examples/rate_limiting_simple.py:

import json
import time
import multiprocessing.dummy as mp
from queue import Empty

from restb.sdk import *
from restb.sdk.api import service


# set allotted number of requests per second
# this is defined here as it could be retrieved through some external mechanism
__requests_per_second = 4


# lambda helper for getting current time in millis
now_millis = lambda: int(round(time.time() * 1000))


def image_process_thread(url, client_key, queue, results):
    while True:
        # get image URL entry to process
        entry = None
        try:
            entry = queue.get(block=False)
        except Empty:
            # nothing queued right now; fall through and sleep below
            pass
        if entry:
            image_id = entry['id']
            img_url = entry['url']
            model_id = entry['model']
            if img_url == 'shutdown':
                print('thread shutting down')
                break
            params = __PARAMS.copy()  # note the module variables as defined in restb/sdk/__init__.py
            params['client_key'] = client_key
            params['image_url'] = img_url
            params['model_id'] = model_id
            endpoint = __ENDPOINT
            start_time = now_millis()
            resp = service(url=url, endpoint=endpoint, params=params)
            end_time = now_millis()
            msg = '[{http}] thread [{thread}] {msg}'
            if resp.status_code == 200:
                vals = json.loads(resp.text)
                results.put(dict(id=image_id, model=model_id, result=vals['response']))
                total = end_time - start_time
                print(msg.format(
                    http=resp.status_code,
                    thread=mp.current_process().name,
                    msg='processed request in [' + str(total) + '] ms')
                )
            elif resp.status_code == 429:
                # handle over-rate limit retrying
                print(msg.format(
                    http=resp.status_code,
                    thread=mp.current_process().name,
                    msg='surpassed rate limit, trying again')
                )
                # re-queue entry and try again, then sleep for ideal average time between requests
                queue.put(entry)
                time.sleep(1 / float(__requests_per_second))
        else:
            time.sleep(1)

The above function image_process_thread is designed to be used as a Python multiprocessing process or thread. For additional information about Python multiprocessing, please refer to the official documentation. The general algorithm for the above is as follows:

  1. Continuously loop until signaled to exit.

  2. Within the loop, check for queued units of work via the queue object (which is of type multiprocessing.Queue). The read is non-blocking, so an empty queue raises an exception instead of waiting (see the sketch after this list).

  3. If there are no units of work in the queue, sleep for a second and continue to the next loop iteration (thereby restarting from step 2).

  4. If there is an entry in the queue (namely, a combination of image_url and model_id), make the API call just like in the basic example.

  5. This time, we need to check the HTTP response code:

    a. if it's a 200 success code, process the response. In this example, we simply print out the response and put it into a results queue.

    b. if it's a 429, re-queue the entry for another attempt and sleep briefly (the code sleeps 1 / n seconds, where n is the allotted requests per second).
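
Since the worker polls with queue.get(block=False), the standard library's queue.Empty exception signals "nothing to do" rather than an error. Here is a minimal standalone sketch of that pattern; the try_get helper and the queued entry are illustrative, not part of the SDK:

import multiprocessing.dummy as mp
from queue import Empty


def try_get(q):
    # non-blocking read: return an entry if one is queued, else None
    try:
        return q.get(block=False)
    except Empty:
        return None


work = mp.Queue()
# model id below is a placeholder, not a real Restb.ai model
work.put(dict(id=1, url='https://demo.restb.ai/images/demo/demo-1.jpg', model='example_model'))
print(try_get(work))  # the queued entry
print(try_get(work))  # None: the queue is empty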

So now we have the code that each thread will run, but we still need to spawn the threads (or processes), so let's add a function to create (and clean up) all the necessary threads. Add the following function to our new Python script:

def test_api(client_key):

    # 1. create test image data and both processing and result queues
    urls = ['https://demo.restb.ai/images/demo/demo-1.jpg',
            'https://demo.restb.ai/images/demo/demo-2.jpg',
            'https://demo.restb.ai/images/demo/demo-3.jpg',
            'https://demo.restb.ai/images/demo/demo-4.jpg',
            'https://demo.restb.ai/images/demo/demo-5.jpg',
            'https://demo.restb.ai/images/demo/demo-6.jpg']
    queue = mp.Queue()
    image_id = 1
    for url in urls:
        for model in __MODELS:
            queue.put(dict(id=image_id, url=url, model=model))
        image_id += 1
    results = mp.Queue()

    # 2. Pick which API endpoint to use (US vs. EU)
    url = __URL_US

    # 3. Define concurrency specific objects
    # TBD later

    # 4. Spawn processes/threads to process the images in the queue
    pool = []
    for i in range(__requests_per_second):
        # pass in necessary parameters to thread, including client key, etc.
        p = mp.Process(target=image_process_thread,
                       args=(url, client_key, queue, results))
        pool.append(p)
        p.start()

    # 5. clean-up after queue has been processed with "poison pill"
    while not queue.empty():
        # wait for queue to be processed
        time.sleep(1)
    for i in pool:
        # seed shutdown messages / poison pills
        queue.put(dict(id=-1, url='shutdown', model='shutdown'))
    for p in pool:
        # enforce clean shutdown of threads
        p.join()

    # 6. finally, return accumulated results
    return results

The above test_api() function works as follows:

  1. Create a sample queue of work. This queue consists of image_urls in conjunction with individual solutions to test (currently all 6 applicable solutions for real estate, so with 6 image_urls that makes for 36 entries total).
  2. After setting the necessary parameters, spawn as many threads / processes as there are allotted requests per second (e.g. if allotted 5 rps, spawn 5 threads).
  3. This function acts as the main thread, which sleeps until all entries in the queue are processed. Given that this is an example, once the queue is drained, the main thread puts "poison pills" into the queue, which are simply messages telling the processing threads to shut down cleanly (see the sketch after this list).
  4. Finally, it returns the accumulated results queue.
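
The poison-pill shutdown works because each worker consumes exactly one sentinel and then exits, so seeding one pill per thread guarantees every thread shuts down. Here is a minimal standalone sketch of the pattern; the worker function and payloads are illustrative:

import multiprocessing.dummy as mp


def worker(q):
    while True:
        item = q.get()  # blocking read is fine here: a pill always arrives
        if item == 'shutdown':  # poison pill: exit the loop
            break
        print('processing', item)


q = mp.Queue()
pool = [mp.Process(target=worker, args=(q,)) for _ in range(2)]
for p in pool:
    p.start()
for item in ['a', 'b', 'c']:
    q.put(item)
for _ in pool:
    q.put('shutdown')  # one pill per worker
for p in pool:
    p.join()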

The last missing pieces are a run() function to print out the results and an invocation from the previously created run.py script:

run() function:

def run(client_key):
    output = test_api(client_key)
    print('\n\nFinal results queue:')
    results = {}
    while not output.empty():
        # accumulate differing solution results for an image ID together
        result = output.get()
        if result['id'] not in results:
            results[result['id']] = {result['model']: result['result']}
        else:
            results[result['id']][result['model']] = result['result']
    for i in range(len(results.keys())):
        for k, v in sorted(results[i+1].items()):
            print('[{id}] [{model}] {res}'.format(id=i+1, model=k, res=v))

run.py:

# assumes the example modules from parts 1 and 2 live in restb/examples
from restb.examples import basic, multipredict, rate_limiting_simple

if __name__ == '__main__':
    client_key = 'YOUR_CLIENT_KEY_HERE'
    print('1. running basic example')
    basic.run(client_key)
    print('2. running multipredict example')
    multipredict.run(client_key)
    print('3. running simple rate limiting example')
    rate_limiting_simple.run(client_key)

At this point, we have a functioning rate-limited client that works with both the newer APIGW requests-per-second rate limiting approach and the older "concurrent threads" approach. However, as-is we can't glean anything meaningful about its performance, so let's add some statistics tracking:

image_process_thread:

def image_process_thread(url, client_key, queue, results,
                         lock_stats, counter, avg_req_time, time_start, time_end):
...
            if resp.status_code == 200:
                vals = json.loads(resp.text)
                results.put(dict(id=image_id, model=model_id, result=vals['response']))
                total = end_time - start_time
                print(msg.format(
                    http=resp.status_code,
                    thread=mp.current_process().name,
                    msg='processed request in [' + str(total) + '] ms')
                )
                # increment counter
                lock_stats.acquire()
                counter.value += 1
                avg_req_time.value += total
                if start_time < time_start.value:
                    time_start.value = start_time
                if end_time > time_end.value:
                    time_end.value = end_time
                lock_stats.release()
            elif resp.status_code == 429:
...

test_api:

...
    # 3. Define concurrency specific objects
    # stats objects
    lock_stats = mp.Lock()
    counter = mp.Value('i', 0)
    avg_req_time = mp.Value('d', 0)  # 'd' (double): single-precision 'f' loses millisecond accuracy
    time_start = mp.Value('d', 999999999999999)  # epoch millis exceed float32 precision
    time_end = mp.Value('d', 0)
...
    for i in range(__requests_per_second):
        # pass in necessary parameters to thread, including client key, etc.
        p = mp.Process(target=image_process_thread,
                       args=(url, client_key, queue, results,
                             lock_stats, counter, avg_req_time, time_start, time_end))
        pool.append(p)
        p.start()
...
    # 6. finally, return accumulated results
    total = time_end.value - time_start.value
    print('[{requests}] requests processed in [{seconds}] seconds with average time [{time}] ms, total throughput: [{throughput}] rps'.format(
        requests=counter.value,
        seconds=str(round(total / 1000.0, 1)),
        time=str(round(avg_req_time.value / counter.value, 0)),
        throughput=str(round(counter.value / (total / 1000.0), 2))
    ))
    return results

Note that the stats objects use several multiprocessing concurrency constructs: lock_stats is a Lock object that functions as a mutex, and the counters are Value objects, which are designed to be shared across multiprocessing processes/threads.
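
To see these two constructs in isolation, here is a minimal sketch that guards a shared Value with a Lock; the thread count and loop bound are arbitrary:

import multiprocessing.dummy as mp

counter = mp.Value('i', 0)
lock = mp.Lock()


def bump(n):
    for _ in range(n):
        # guard the read-modify-write so concurrent increments don't race
        with lock:
            counter.value += 1


pool = [mp.Process(target=bump, args=(1000,)) for _ in range(4)]
for p in pool:
    p.start()
for p in pool:
    p.join()
print(counter.value)  # 4000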

At this point, running the restb.examples.run command should yield, in addition to a lot of logging output, performance numbers like the following:

[36] requests processed in [11.4] seconds with average time [1164.0] ms, total throughput: [3.14] rps
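
These numbers are internally consistent: with 4 worker threads and an average request time of about 1164 ms, the theoretical ceiling is roughly 4 / 1.164 ≈ 3.4 rps, and the observed 3.14 rps sits a bit below that because of the pauses taken after 429 responses and the one-second sleeps when the queue is momentarily empty.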