diff --git a/docs/library.mdx b/docs/library.mdx
index a1e01097d..228c6d95f 100644
--- a/docs/library.mdx
+++ b/docs/library.mdx
@@ -222,6 +222,7 @@ Settings:
 - `use_error_correction_framework`: whether to use the error correction framework. Defaults to `True`. If set to `True`, PandasAI will try to correct the errors in the code generated by the LLM with further calls to the LLM. If set to `False`, PandasAI will not try to correct the errors in the code generated by the LLM.
 - `max_retries`: the maximum number of retries to use when using the error correction framework. Defaults to `3`. You can use this setting to override the default number of retries.
 - `custom_whitelisted_dependencies`: the custom whitelisted dependencies to use. Defaults to `{}`. You can use this setting to override the default custom whitelisted dependencies. You can find more information about custom whitelisted dependencies [here](/custom-whitelisted-dependencies).
+- `security`: the security level to apply when handling queries and running LLM-generated code. Accepts `"none"`, `"standard"`, or `"advanced"`. Defaults to `"standard"`. The `"standard"` and `"advanced"` levels detect malicious intent in user queries and prevent the execution of potentially harmful code. Because these checks apply strict rules, they can occasionally flag benign queries as harmful; you can disable them by setting `security` to `"none"`.
 
 ## Demo in Google Colab
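A minimal sketch of how the new setting is passed (the DataFrame and query are illustrative, and an LLM still has to be configured as usual, which is omitted here):

```python
import pandas as pd

from pandasai import Agent

# Illustrative data; any DataFrame will do.
df = pd.DataFrame({"country": ["Norway", "Japan"], "gdp": [4.1, 5.0]})

# "standard" is the default; "none" disables the malicious-query and
# malicious-code checks introduced in the changes below.
agent = Agent(df, config={"security": "none"})
agent.chat("Which country has the highest gdp?")
```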
diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py
index c6203111d..0c29ada1c 100644
--- a/pandasai/agent/base.py
+++ b/pandasai/agent/base.py
@@ -259,7 +259,10 @@ def chat(self, query: str, output_type: Optional[str] = None):
 
         self.assign_prompt_id()
 
-        if self.check_malicious_keywords_in_query(query):
+        if self.config.security in [
+            "standard",
+            "advanced",
+        ] and self.check_malicious_keywords_in_query(query):
             raise MaliciousQueryError(
                 "The query contains references to io or os modules or b64decode method which can be used to execute or access system resources in unsafe ways."
             )
diff --git a/pandasai/helpers/optional.py b/pandasai/helpers/optional.py
index 2bcdcc146..3b3f23289 100644
--- a/pandasai/helpers/optional.py
+++ b/pandasai/helpers/optional.py
@@ -51,7 +51,7 @@ def get_version(module: types.ModuleType) -> str:
     return version
 
 
-def get_environment(additional_deps: List[dict]) -> dict:
+def get_environment(additional_deps: List[dict], secure: bool = True) -> dict:
     """
     Returns the environment for the code to be executed.
 
@@ -73,24 +73,43 @@ def get_environment(additional_deps: List[dict]) -> dict:
         },
     }
 
-    env["pd"] = RestrictedPandas()
-    env["plt"] = RestrictedMatplotlib()
-    env["np"] = RestrictedNumpy()
+    if secure:
+        env["pd"] = RestrictedPandas()
+        env["plt"] = RestrictedMatplotlib()
+        env["np"] = RestrictedNumpy()
 
-    for lib in additional_deps:
-        if lib["name"] == "seaborn":
-            from pandasai.safe_libs.restricted_seaborn import RestrictedSeaborn
+        for lib in additional_deps:
+            if lib["name"] == "seaborn":
+                from pandasai.safe_libs.restricted_seaborn import RestrictedSeaborn
 
-            env["sns"] = RestrictedSeaborn()
+                env["sns"] = RestrictedSeaborn()
 
-        if lib["name"] == "datetime":
-            env["datetime"] = RestrictedDatetime()
+            if lib["name"] == "datetime":
+                env["datetime"] = RestrictedDatetime()
 
-        if lib["name"] == "json":
-            env["json"] = RestrictedJson()
+            if lib["name"] == "json":
+                env["json"] = RestrictedJson()
 
-        if lib["name"] == "base64":
-            env["base64"] = RestrictedBase64()
+            if lib["name"] == "base64":
+                env["base64"] = RestrictedBase64()
+
+    else:
+        env["pd"] = import_dependency("pandas")
+        env["plt"] = import_dependency("matplotlib.pyplot")
+        env["np"] = import_dependency("numpy")
+
+        for lib in additional_deps:
+            if lib["name"] == "seaborn":
+                env["sns"] = import_dependency("seaborn")
+
+            if lib["name"] == "datetime":
+                env["datetime"] = import_dependency("datetime")
+
+            if lib["name"] == "json":
+                env["json"] = import_dependency("json")
+
+            if lib["name"] == "base64":
+                env["base64"] = import_dependency("base64")
 
     return env
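The effect of the new `secure` flag is easiest to see directly. This sketch mirrors the `test_env_for_security` unit test added later in this diff:

```python
from pandasai.helpers.optional import get_environment
from pandasai.safe_libs.restricted_pandas import RestrictedPandas

env = get_environment([], secure=True)  # default: sandboxed wrappers
assert isinstance(env["pd"], RestrictedPandas)

env = get_environment([], secure=False)  # plain modules, no wrapping
assert not isinstance(env["pd"], RestrictedPandas)
```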
diff --git a/pandasai/pipelines/chat/code_cleaning.py b/pandasai/pipelines/chat/code_cleaning.py
index 96dbd91f9..398c10cf9 100644
--- a/pandasai/pipelines/chat/code_cleaning.py
+++ b/pandasai/pipelines/chat/code_cleaning.py
@@ -121,10 +121,14 @@ def _replace_plot_png(self, code):
         return re.sub(r"""(['"])([^'"]*\.png)\1""", r"\1temp_chart.png\1", code)
 
     def get_code_to_run(self, code: str, context: CodeExecutionContext) -> Any:
-        if self._is_malicious_code(code):
+        if self._config.security in [
+            "standard",
+            "advanced",
+        ] and self._is_malicious_code(code):
             raise MaliciousQueryError(
                 "Code shouldn't use 'os', 'io' or 'chr', 'b64decode' functions as this could lead to malicious code execution."
             )
+
         code = self._replace_plot_png(code)
         self._current_code_executed = code
 
@@ -475,7 +479,10 @@ def _extract_fix_dataframe_redeclarations(
         if target_names and self._check_is_df_declaration(node):
             # Construct dataframe from node
             code = "\n".join(code_lines)
-            env = get_environment(self._additional_dependencies)
+            env = get_environment(
+                self._additional_dependencies,
+                secure=self._config.security in ["standard", "advanced"],
+            )
             env["dfs"] = copy.deepcopy(self._get_originals(self._dfs))
             if context.skills_manager.used_skills:
                 for skill_func_name in context.skills_manager.used_skills:
diff --git a/pandasai/pipelines/chat/code_execution.py b/pandasai/pipelines/chat/code_execution.py
index 65acf4656..6ee25ce22 100644
--- a/pandasai/pipelines/chat/code_execution.py
+++ b/pandasai/pipelines/chat/code_execution.py
@@ -153,7 +153,10 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any:
         # List the required dfs, so we can avoid to run the connectors
         # if the code does not need them
         dfs = self._required_dfs(code)
-        environment: dict = get_environment(self._additional_dependencies)
+        environment: dict = get_environment(
+            self._additional_dependencies,
+            secure=self._config.security in ["standard", "advanced"],
+        )
         environment["dfs"] = self._get_originals(dfs)
         if len(environment["dfs"]) == 1:
             environment["df"] = environment["dfs"][0]
diff --git a/pandasai/safe_libs/base_restricted_module.py b/pandasai/safe_libs/base_restricted_module.py
index 3067a3aab..65e1864bd 100644
--- a/pandasai/safe_libs/base_restricted_module.py
+++ b/pandasai/safe_libs/base_restricted_module.py
@@ -4,7 +4,7 @@ def wrapper(*args, **kwargs):
             # Check for any suspicious arguments that might be used for importing
             for arg in args + tuple(kwargs.values()):
                 if isinstance(arg, str) and any(
-                    module in arg.lower()
+                    module == arg.lower()
                     for module in ["io", "os", "subprocess", "sys", "importlib"]
                 ):
                     raise SecurityError(
diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py
index eec8307c4..398c466d7 100644
--- a/pandasai/schemas/df_config.py
+++ b/pandasai/schemas/df_config.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, TypedDict
+from typing import Any, List, Literal, Optional, TypedDict
 
 from pandasai.constants import DEFAULT_CHART_DIRECTORY
 from pandasai.helpers.dataframe_serializer import DataframeSerializerType
@@ -30,6 +30,7 @@ class Config(BaseModel):
     log_server: LogServerConfig = None
     direct_sql: bool = False
     dataframe_serializer: DataframeSerializerType = DataframeSerializerType.CSV
+    security: Literal["standard", "none", "advanced"] = "standard"
 
     class Config:
         arbitrary_types_allowed = True
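The one-character change in `base_restricted_module.py` is easy to miss: the old substring check raised a `SecurityError` for any string argument that merely contained a blocked module name, so benign words such as "osmosis" or "ionosphere" (the very strings the updated agent test below treats as safe) were rejected. The exact-match check only trips on the bare name:

```python
blocked = ["io", "os", "subprocess", "sys", "importlib"]

any(module in "osmosis" for module in blocked)   # True: "os" is a substring -> false positive
any(module == "osmosis" for module in blocked)   # False: no exact match -> allowed
any(module == "os" for module in blocked)        # True: bare "os" is still blocked
```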
diff --git a/tests/unit_tests/agent/test_agent.py b/tests/unit_tests/agent/test_agent.py
index 59b9e6958..c1e9986fd 100644
--- a/tests/unit_tests/agent/test_agent.py
+++ b/tests/unit_tests/agent/test_agent.py
@@ -710,3 +710,37 @@ def test_query_detection(self, sample_df, config):
         for query in safe_queries:
             response = agent.chat(query)
             assert "Unfortunately, I was not able to get your answers" not in response
+
+    def test_query_detection_disable_security(self, sample_df, config):
+        config["security"] = "none"
+        agent = Agent(sample_df, config, memory_size=10)
+
+        malicious_queries = [
+            "import os",
+            "import io",
+            "chr(97)",
+            "base64.b64decode",
+            "file = open('file.txt', 'os')",
+            "os.system('rm -rf /')",
+            "io.open('file.txt', 'w')",
+        ]
+
+        expected_malicious_response = (
+            """Unfortunately, I was not able to get your answers, because of the following error:\n\n"""
+            """The query contains references to io or os modules or b64decode method which can be used to execute or access system resources in unsafe ways.\n"""
+        )
+
+        for query in malicious_queries:
+            response = agent.chat(query)
+            assert response != expected_malicious_response
+
+        safe_queries = [
+            "print('Hello world')",
+            "through osmosis",
+            "the ionosphere",
+            "the capital of Norway is Oslo",
+        ]
+
+        for query in safe_queries:
+            response = agent.chat(query)
+            assert "Unfortunately, I was not able to get your answers" not in response
diff --git a/tests/unit_tests/helpers/test_optional_dependency.py b/tests/unit_tests/helpers/test_optional_dependency.py
index db7ed2c2f..e06008ed7 100644
--- a/tests/unit_tests/helpers/test_optional_dependency.py
+++ b/tests/unit_tests/helpers/test_optional_dependency.py
@@ -9,6 +9,9 @@
 import pytest
 
 from pandasai.helpers.optional import VERSIONS, get_environment, import_dependency
+from pandasai.safe_libs.restricted_matplotlib import RestrictedMatplotlib
+from pandasai.safe_libs.restricted_numpy import RestrictedNumpy
+from pandasai.safe_libs.restricted_pandas import RestrictedPandas
 
 
 def test_import_optional():
@@ -91,3 +94,15 @@ def test_env_for_necessary_deps():
     assert "pd" in env
     assert "plt" in env
     assert "np" in env
+
+
+def test_env_for_security():
+    env = get_environment([], secure=True)
+    assert "pd" in env and isinstance(env["pd"], RestrictedPandas)
+    assert "plt" in env and isinstance(env["plt"], RestrictedMatplotlib)
+    assert "np" in env and isinstance(env["np"], RestrictedNumpy)
+
+    env = get_environment([], secure=False)
+    assert "pd" in env and not isinstance(env["pd"], RestrictedPandas)
+    assert "plt" in env and not isinstance(env["plt"], RestrictedMatplotlib)
+    assert "np" in env and not isinstance(env["np"], RestrictedNumpy)
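The reworked code-cleaning tests below now request a `config` fixture and build a `Config(**config)` from it. The fixture itself is not part of this diff; a hypothetical minimal stand-in would look like:

```python
import pytest

from pandasai.llm.fake import FakeLLM


@pytest.fixture
def config() -> dict:
    # Hypothetical stand-in; the suite's real fixture lives in its conftest.
    return {"llm": FakeLLM(), "enable_cache": False}
```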
@patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_redeclarations_with_subscript( - self, mock_head, context: PipelineContext + self, mock_head, context: PipelineContext, config: dict ): df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) mock_head.return_value = df @@ -672,6 +676,7 @@ def test_fix_dataframe_redeclarations_with_subscript( code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] + code_cleaning._config = Config(**config) context.dfs = [pandas_connector] python_code = """ @@ -689,7 +694,7 @@ def test_fix_dataframe_redeclarations_with_subscript( @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_redeclarations_with_subscript_and_data_variable( - self, mock_head, context: PipelineContext + self, mock_head, context: PipelineContext, config: dict ): data = { "country": ["China", "United States", "Japan", "Germany", "United Kingdom"], @@ -701,6 +706,7 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable( code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] + code_cleaning._config = Config(**config) context.dfs = [pandas_connector] python_code = """ @@ -723,7 +729,7 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable( @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_redeclarations_and_data_variable( - self, mock_head, context: PipelineContext + self, mock_head, context: PipelineContext, config: Config ): data = { "country": ["China", "United States", "Japan", "Germany", "United Kingdom"], @@ -735,6 +741,7 @@ def test_fix_dataframe_redeclarations_and_data_variable( code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] + code_cleaning._config = Config(**config) context.dfs = [pandas_connector] python_code = """ @@ -933,6 +940,19 @@ def test_clean_code_raise_import_with_restricted_using_import_statement( with pytest.raises(MaliciousQueryError): code_cleaning.execute(malicious_code, context=context, logger=logger) + def test_clean_code_raise_not_whitelisted_lib_with_none_security( + self, + code_cleaning: CodeCleaning, + context: PipelineContext, + logger: Logger, + ): + builtins_code = """import scipy +result = {'type': 'number', 'value': set([1, 2, 3])}""" + + context.config.security = "none" + with pytest.raises(BadImportError): + code_cleaning.execute(builtins_code, context=context, logger=logger) + def test_clean_code_with_pltshow_in_code( self, code_cleaning: CodeCleaning,