|
| 1 | +# |
| 2 | +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# |
| 16 | +from abc import ABC |
| 17 | + |
| 18 | +import pandas as pd |
| 19 | +from peewee import MySQLDatabase, PostgresqlDatabase |
| 20 | +from agent.component.base import ComponentBase, ComponentParamBase |
| 21 | + |
| 22 | + |
| 23 | +class ExeSQLParam(ComponentParamBase): |
| 24 | + """ |
| 25 | + Define the ExeSQL component parameters. |
| 26 | + """ |
| 27 | + |
| 28 | + def __init__(self): |
| 29 | + super().__init__() |
| 30 | + self.db_type = "mysql" |
| 31 | + self.database = "" |
| 32 | + self.username = "" |
| 33 | + self.host = "" |
| 34 | + self.port = 3306 |
| 35 | + self.password = "" |
| 36 | + self.loop = 3 |
| 37 | + self.top_n = 30 |
| 38 | + |
| 39 | + def check(self): |
| 40 | + self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb']) |
| 41 | + self.check_empty(self.database, "Database name") |
| 42 | + self.check_empty(self.username, "database username") |
| 43 | + self.check_empty(self.host, "IP Address") |
| 44 | + self.check_positive_integer(self.port, "IP Port") |
| 45 | + self.check_empty(self.password, "Database password") |
| 46 | + self.check_positive_integer(self.top_n, "Number of records") |
| 47 | + |
| 48 | + |
| 49 | +class ExeSQL(ComponentBase, ABC): |
| 50 | + component_name = "ExeSQL" |
| 51 | + |
| 52 | + def _run(self, history, **kwargs): |
| 53 | + if not hasattr(self, "_loop"): |
| 54 | + setattr(self, "_loop", 0) |
| 55 | + if self._loop >= self._param.loop: |
| 56 | + self._loop = 0 |
| 57 | + raise Exception("Maximum loop time exceeds. Can't query the correct data via sql statement.") |
| 58 | + self._loop += 1 |
| 59 | + |
| 60 | + ans = self.get_input() |
| 61 | + ans = "".join(ans["content"]) if "content" in ans else "" |
| 62 | + if not ans: |
| 63 | + return ExeSQL.be_output("SQL statement not found!") |
| 64 | + |
| 65 | + if self._param.db_type in ["mysql", "mariadb"]: |
| 66 | + db = MySQLDatabase(self._param.database, user=self._param.username, host=self._param.host, |
| 67 | + port=self._param.port, password=self._param.password) |
| 68 | + elif self._param.db_type == 'postgresql': |
| 69 | + db = PostgresqlDatabase(self._param.database, user=self._param.username, host=self._param.host, |
| 70 | + port=self._param.port, password=self._param.password) |
| 71 | + |
| 72 | + try: |
| 73 | + db.connect() |
| 74 | + query = db.execute_sql(ans) |
| 75 | + sql_res = [{"content": rec + "\n"} for rec in [str(i) for i in query.fetchall()]] |
| 76 | + db.close() |
| 77 | + except Exception as e: |
| 78 | + return ExeSQL.be_output("**Error**:" + str(e)) |
| 79 | + |
| 80 | + if not sql_res: |
| 81 | + return ExeSQL.be_output("No record in the database!") |
| 82 | + |
| 83 | + sql_res.insert(0, {"content": "Number of records retrieved from the database is " + str(len(sql_res)) + "\n"}) |
| 84 | + df = pd.DataFrame(sql_res[0:self._param.top_n + 1]) |
| 85 | + return ExeSQL.be_output(df.to_markdown()) |
0 commit comments