Skip to content

Commit

Permalink
Fix service name adding
Browse files Browse the repository at this point in the history
  • Loading branch information
hjk1030 committed May 15, 2024
1 parent 700301b commit 7e04e5b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
18 changes: 13 additions & 5 deletions aidb/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ def execute(self, query: str, **kwargs):
async def clear_ml_cache(self, service_name_list = None):
'''
Clear the cache and output table if the ML model has changed.
For each cached inference service, build the reference graph of the tables based on fk constraints,
and then delete the tables following the graph's topological order to maintain integrity during deletion.
Delete the tables following the inference services' topological order to maintain integrity during deletion.
service_name_list: the name of all the changed services. A list of str or None.
If the service name list is not given, the output for all the services will be cleared.
'''
Expand All @@ -63,13 +62,22 @@ async def clear_ml_cache(self, service_name_list = None):
if service_name_list is None:
service_name_list = [bounded_service.service.name for bounded_service in service_ordering]
service_name_list = set(service_name_list)

# Get all the services that need to be cleared because of foreign key constraints
for bounded_service in service_ordering:
if bounded_service.service.name in service_name_list:
for in_edge in self._config.table_graph.in_edges(bounded_service.service.name):
service_name_list.add(in_edge[0])

# Clear the services in reversed topological order
for bounded_service in reversed(service_ordering):
if isinstance(bounded_service, CachedBoundInferenceService):
if bounded_service.service.name in service_name_list:
for input_column in bounded_service.binding.input_columns:
service_name_list.add(input_column.split('.')[0])
asyncio_run(conn.execute(delete(bounded_service._cache_table)))
output_tables_to_be_deleted = set()
for output_column in bounded_service.binding.output_columns:
asyncio_run(conn.execute(delete(bounded_service._tables[output_column.split('.')[0]]._table)))
output_tables_to_be_deleted.add(output_column.split('.')[0])
for table_name in output_tables_to_be_deleted:
asyncio_run(conn.execute(delete(bounded_service._tables[table_name]._table)))
else:
logger.debug(f"Service binding for {bounded_service.service.name} is not cached")
2 changes: 0 additions & 2 deletions tests/tests_caching_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ async def test_num_infer_calls(self):
del aidb_engine
p.terminate()
p.join()
time.sleep(1)

async def test_only_one_service_deleted(self):
'''
Expand Down Expand Up @@ -160,7 +159,6 @@ async def test_only_one_service_deleted(self):
del aidb_engine
p.terminate()
p.join()
time.sleep(1)

if __name__ == '__main__':
unittest.main()

0 comments on commit 7e04e5b

Please sign in to comment.