feat(security): add security config to disable it (#1498)
* feat(security): add security config to disable it

* fix: linting errors

* fix(safety): push exact match for get attributes

* add additional test case

* fix: test case

* fix: linting errors

* fix: linting errors

* docs(config): update config doc to add new config attribute
ArslanSaleem authored Jan 2, 2025
1 parent cfeb071 commit 554a638
Showing 10 changed files with 128 additions and 25 deletions.
1 change: 1 addition & 0 deletions docs/library.mdx
@@ -222,6 +222,7 @@ Settings:
 - `use_error_correction_framework`: whether to use the error correction framework. Defaults to `True`. If set to `True`, PandasAI will try to correct the errors in the code generated by the LLM with further calls to the LLM. If set to `False`, PandasAI will not try to correct the errors in the code generated by the LLM.
 - `max_retries`: the maximum number of retries to use when using the error correction framework. Defaults to `3`. You can use this setting to override the default number of retries.
 - `custom_whitelisted_dependencies`: the custom whitelisted dependencies to use. Defaults to `{}`. You can use this setting to override the default custom whitelisted dependencies. You can find more information about custom whitelisted dependencies [here](/custom-whitelisted-dependencies).
+- `security`: the security level applied to user queries and LLM-generated code. Accepts `"none"`, `"standard"`, or `"advanced"`; defaults to `"standard"`. The `"standard"` and `"advanced"` levels detect malicious intent in user queries and prevent the execution of potentially harmful generated code. Because these checks are strict, they can occasionally flag benign queries as harmful; set `security` to `"none"` to disable them.
 
 ## Demo in Google Colab
 
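As a usage illustration of the new setting, a minimal sketch (the sample dataframe, the query, and the LLM defaults are assumptions for illustration; only the `security` key comes from this commit):

```python
import pandas as pd
from pandasai import Agent

# Hypothetical sample data; any dataframe works here.
df = pd.DataFrame({"country": ["Norway", "Japan"], "gdp": [434, 4231]})

# "standard" (the default) and "advanced" enable the malicious-query and
# malicious-code checks; "none" switches both checks off.
agent = Agent(df, config={"security": "none"})
agent.chat("Which country has the highest gdp?")
```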
5 changes: 4 additions & 1 deletion pandasai/agent/base.py
@@ -259,7 +259,10 @@ def chat(self, query: str, output_type: Optional[str] = None):
 
         self.assign_prompt_id()
 
-        if self.check_malicious_keywords_in_query(query):
+        if self.config.security in [
+            "standard",
+            "advanced",
+        ] and self.check_malicious_keywords_in_query(query):
             raise MaliciousQueryError(
                 "The query contains references to io or os modules or b64decode method which can be used to execute or access system resources in unsafe ways."
             )
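The gate only short-circuits the existing keyword check when `security` is `"none"`. As a rough, hypothetical approximation of what `check_malicious_keywords_in_query` looks for (the pattern and keyword list below are assumptions based on the error message, not the real implementation):

```python
import re

# Assumed keyword set, mirroring the error message above: references to
# the io/os modules and to base64.b64decode.
DANGEROUS_PATTERN = re.compile(r"\b(os|io|b64decode)\b")

def check_malicious_keywords_in_query(query: str) -> bool:
    # True if the query mentions any of the blocked names as a whole word.
    return bool(DANGEROUS_PATTERN.search(query))
```

With `security` set to `"none"`, the call is skipped entirely, so even a query that literally contains `import os` goes straight to the LLM.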
47 changes: 33 additions & 14 deletions pandasai/helpers/optional.py
@@ -51,7 +51,7 @@ def get_version(module: types.ModuleType) -> str:
     return version
 
 
-def get_environment(additional_deps: List[dict]) -> dict:
+def get_environment(additional_deps: List[dict], secure: bool = True) -> dict:
     """
     Returns the environment for the code to be executed.
 
@@ -73,24 +73,43 @@ def get_environment(additional_deps: List[dict]) -> dict:
         },
     }
 
-    env["pd"] = RestrictedPandas()
-    env["plt"] = RestrictedMatplotlib()
-    env["np"] = RestrictedNumpy()
+    if secure:
+        env["pd"] = RestrictedPandas()
+        env["plt"] = RestrictedMatplotlib()
+        env["np"] = RestrictedNumpy()
 
-    for lib in additional_deps:
-        if lib["name"] == "seaborn":
-            from pandasai.safe_libs.restricted_seaborn import RestrictedSeaborn
+        for lib in additional_deps:
+            if lib["name"] == "seaborn":
+                from pandasai.safe_libs.restricted_seaborn import RestrictedSeaborn
 
-            env["sns"] = RestrictedSeaborn()
+                env["sns"] = RestrictedSeaborn()
 
-        if lib["name"] == "datetime":
-            env["datetime"] = RestrictedDatetime()
+            if lib["name"] == "datetime":
+                env["datetime"] = RestrictedDatetime()
 
-        if lib["name"] == "json":
-            env["json"] = RestrictedJson()
+            if lib["name"] == "json":
+                env["json"] = RestrictedJson()
 
-        if lib["name"] == "base64":
-            env["base64"] = RestrictedBase64()
+            if lib["name"] == "base64":
+                env["base64"] = RestrictedBase64()
+
+    else:
+        env["pd"] = import_dependency("pandas")
+        env["plt"] = import_dependency("matplotlib.pyplot")
+        env["np"] = import_dependency("numpy")
+
+        for lib in additional_deps:
+            if lib["name"] == "seaborn":
+                env["sns"] = import_dependency("seaborn")
+
+            if lib["name"] == "datetime":
+                env["datetime"] = import_dependency("datetime")
+
+            if lib["name"] == "json":
+                env["json"] = import_dependency("json")
+
+            if lib["name"] == "base64":
+                env["base64"] = import_dependency("base64")
 
     return env
 
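A quick sketch of what the new `secure` flag changes for a caller (this mirrors the unit test added further below; with `secure=False` the environment holds the real third-party modules instead of the restricted proxies):

```python
from pandasai.helpers.optional import get_environment

# Sandboxed (default): proxy objects that block os/io/subprocess-style escapes.
env = get_environment([], secure=True)
print(type(env["pd"]))  # RestrictedPandas

# Unrestricted: the plain modules, for users who opted out of security checks.
env = get_environment([], secure=False)
print(type(env["pd"]))  # the actual pandas module
```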
11 changes: 9 additions & 2 deletions pandasai/pipelines/chat/code_cleaning.py
@@ -121,10 +121,14 @@ def _replace_plot_png(self, code):
         return re.sub(r"""(['"])([^'"]*\.png)\1""", r"\1temp_chart.png\1", code)
 
     def get_code_to_run(self, code: str, context: CodeExecutionContext) -> Any:
-        if self._is_malicious_code(code):
+        if self._config.security in [
+            "standard",
+            "advanced",
+        ] and self._is_malicious_code(code):
             raise MaliciousQueryError(
                 "Code shouldn't use 'os', 'io' or 'chr', 'b64decode' functions as this could lead to malicious code execution."
             )
+
         code = self._replace_plot_png(code)
         self._current_code_executed = code
 
@@ -475,7 +479,10 @@ def _extract_fix_dataframe_redeclarations(
         if target_names and self._check_is_df_declaration(node):
             # Construct dataframe from node
             code = "\n".join(code_lines)
-            env = get_environment(self._additional_dependencies)
+            env = get_environment(
+                self._additional_dependencies,
+                secure=self._config.security in ["standard", "advanced"],
+            )
             env["dfs"] = copy.deepcopy(self._get_originals(self._dfs))
             if context.skills_manager.used_skills:
                 for skill_func_name in context.skills_manager.used_skills:
5 changes: 4 additions & 1 deletion pandasai/pipelines/chat/code_execution.py
@@ -153,7 +153,10 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any:
         # List the required dfs, so we can avoid to run the connectors
         # if the code does not need them
         dfs = self._required_dfs(code)
-        environment: dict = get_environment(self._additional_dependencies)
+        environment: dict = get_environment(
+            self._additional_dependencies,
+            secure=self._config.security in ["standard", "advanced"],
+        )
         environment["dfs"] = self._get_originals(dfs)
         if len(environment["dfs"]) == 1:
             environment["df"] = environment["dfs"][0]
2 changes: 1 addition & 1 deletion pandasai/safe_libs/base_restricted_module.py
@@ -4,7 +4,7 @@ def wrapper(*args, **kwargs):
             # Check for any suspicious arguments that might be used for importing
             for arg in args + tuple(kwargs.values()):
                 if isinstance(arg, str) and any(
-                    module in arg.lower()
+                    module == arg.lower()
                     for module in ["io", "os", "subprocess", "sys", "importlib"]
                 ):
                     raise SecurityError(
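The one-character change above (`in` to `==`) turns a substring test into an exact-match test, which is what the "push exact match" commit message refers to. A minimal sketch of the difference it makes:

```python
blocked = ["io", "os", "subprocess", "sys", "importlib"]

arg = "cosine"  # benign argument that merely contains the letters "os"

# Old behaviour: substring match, so "cosine" is a false positive.
any(module in arg.lower() for module in blocked)   # True

# New behaviour: only the bare module name itself is rejected.
any(module == arg.lower() for module in blocked)   # False
any(module == "os" for module in blocked)          # True
```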
3 changes: 2 additions & 1 deletion pandasai/schemas/df_config.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, TypedDict
+from typing import Any, List, Literal, Optional, TypedDict
 
 from pandasai.constants import DEFAULT_CHART_DIRECTORY
 from pandasai.helpers.dataframe_serializer import DataframeSerializerType
@@ -30,6 +30,7 @@ class Config(BaseModel):
     log_server: LogServerConfig = None
     direct_sql: bool = False
     dataframe_serializer: DataframeSerializerType = DataframeSerializerType.CSV
+    security: Literal["standard", "none", "advanced"] = "standard"
 
     class Config:
         arbitrary_types_allowed = True
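Because `security` is declared as a `Literal`, pydantic rejects unknown values at construction time. A small sketch with a trimmed-down model (the real `Config` has many more fields):

```python
from typing import Literal

from pydantic import BaseModel

class Config(BaseModel):
    security: Literal["standard", "none", "advanced"] = "standard"

Config()                    # security == "standard" (default)
Config(security="none")     # explicitly disable the checks
Config(security="strict")   # raises pydantic.ValidationError
```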
34 changes: 34 additions & 0 deletions tests/unit_tests/agent/test_agent.py
@@ -710,3 +710,37 @@ def test_query_detection(self, sample_df, config):
         for query in safe_queries:
             response = agent.chat(query)
             assert "Unfortunately, I was not able to get your answers" not in response
+
+    def test_query_detection_disable_security(self, sample_df, config):
+        config["security"] = "none"
+        agent = Agent(sample_df, config, memory_size=10)
+
+        malicious_queries = [
+            "import os",
+            "import io",
+            "chr(97)",
+            "base64.b64decode",
+            "file = open('file.txt', 'os')",
+            "os.system('rm -rf /')",
+            "io.open('file.txt', 'w')",
+        ]
+
+        expected_malicious_response = (
+            """Unfortunately, I was not able to get your answers, because of the following error:\n\n"""
+            """The query contains references to io or os modules or b64decode method which can be used to execute or access system resources in unsafe ways.\n"""
+        )
+
+        for query in malicious_queries:
+            response = agent.chat(query)
+            assert response != expected_malicious_response
+
+        safe_queries = [
+            "print('Hello world')",
+            "through osmosis",
+            "the ionosphere",
+            "the capital of Norway is Oslo",
+        ]
+
+        for query in safe_queries:
+            response = agent.chat(query)
+            assert "Unfortunately, I was not able to get your answers" not in response
15 changes: 15 additions & 0 deletions tests/unit_tests/helpers/test_optional_dependency.py
@@ -9,6 +9,9 @@
 import pytest
 
 from pandasai.helpers.optional import VERSIONS, get_environment, import_dependency
+from pandasai.safe_libs.restricted_matplotlib import RestrictedMatplotlib
+from pandasai.safe_libs.restricted_numpy import RestrictedNumpy
+from pandasai.safe_libs.restricted_pandas import RestrictedPandas
 
 
 def test_import_optional():
@@ -91,3 +91,15 @@ def test_env_for_necessary_deps():
     assert "pd" in env
     assert "plt" in env
     assert "np" in env
+
+
+def test_env_for_security():
+    env = get_environment([], secure=True)
+    assert "pd" in env and isinstance(env["pd"], RestrictedPandas)
+    assert "plt" in env and isinstance(env["plt"], RestrictedMatplotlib)
+    assert "np" in env and isinstance(env["np"], RestrictedNumpy)
+
+    env = get_environment([], secure=False)
+    assert "pd" in env and not isinstance(env["pd"], RestrictedPandas)
+    assert "plt" in env and not isinstance(env["plt"], RestrictedMatplotlib)
+    assert "np" in env and not isinstance(env["np"], RestrictedNumpy)
30 changes: 25 additions & 5 deletions tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py
@@ -581,13 +581,16 @@ def test_clean_code_using_multi_incorrect_sql_table(
         assert str(excinfo.value) == ("Query uses unauthorized table: table1.")
 
     @patch("pandasai.connectors.pandas.PandasConnector.head")
-    def test_fix_dataframe_redeclarations(self, mock_head, context: PipelineContext):
+    def test_fix_dataframe_redeclarations(
+        self, mock_head, context: PipelineContext, config: dict
+    ):
         df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
         mock_head.return_value = df
         pandas_connector = PandasConnector({"original_df": df})
 
         code_cleaning = CodeCleaning()
         code_cleaning._dfs = [pandas_connector]
+        code_cleaning._config = Config(**config)
         context.dfs = [pandas_connector]
 
         python_code = """
@@ -605,14 +608,15 @@ def test_fix_dataframe_multiline_redeclarations(self, mock_head, context: PipelineContext)
 
     @patch("pandasai.connectors.pandas.PandasConnector.head")
     def test_fix_dataframe_multiline_redeclarations(
-        self, mock_head, context: PipelineContext
+        self, mock_head, context: PipelineContext, config: dict
     ):
         df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
         mock_head.return_value = df
         pandas_connector = PandasConnector({"original_df": df})
 
         code_cleaning = CodeCleaning()
         code_cleaning._dfs = [pandas_connector]
+        code_cleaning._config = Config(**config)
         context.dfs = [pandas_connector]
 
         python_code = """
@@ -664,14 +668,15 @@ def test_fix_dataframe_no_redeclarations(self, mock_head, context: PipelineContext)
 
     @patch("pandasai.connectors.pandas.PandasConnector.head")
     def test_fix_dataframe_redeclarations_with_subscript(
-        self, mock_head, context: PipelineContext
+        self, mock_head, context: PipelineContext, config: dict
     ):
         df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
         mock_head.return_value = df
         pandas_connector = PandasConnector({"original_df": df})
 
         code_cleaning = CodeCleaning()
         code_cleaning._dfs = [pandas_connector]
+        code_cleaning._config = Config(**config)
         context.dfs = [pandas_connector]
 
         python_code = """
@@ -689,7 +694,7 @@ def test_fix_dataframe_redeclarations_with_subscript(
 
     @patch("pandasai.connectors.pandas.PandasConnector.head")
     def test_fix_dataframe_redeclarations_with_subscript_and_data_variable(
-        self, mock_head, context: PipelineContext
+        self, mock_head, context: PipelineContext, config: dict
     ):
         data = {
             "country": ["China", "United States", "Japan", "Germany", "United Kingdom"],
@@ -701,6 +706,7 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable(
 
         code_cleaning = CodeCleaning()
         code_cleaning._dfs = [pandas_connector]
+        code_cleaning._config = Config(**config)
         context.dfs = [pandas_connector]
 
         python_code = """
@@ -723,7 +729,7 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable(
 
     @patch("pandasai.connectors.pandas.PandasConnector.head")
     def test_fix_dataframe_redeclarations_and_data_variable(
-        self, mock_head, context: PipelineContext
+        self, mock_head, context: PipelineContext, config: Config
    ):
         data = {
             "country": ["China", "United States", "Japan", "Germany", "United Kingdom"],
@@ -735,6 +741,7 @@ def test_fix_dataframe_redeclarations_and_data_variable(
 
         code_cleaning = CodeCleaning()
         code_cleaning._dfs = [pandas_connector]
+        code_cleaning._config = Config(**config)
         context.dfs = [pandas_connector]
 
         python_code = """
@@ -933,6 +940,19 @@ def test_clean_code_raise_import_with_restricted_using_import_statement(
         with pytest.raises(MaliciousQueryError):
             code_cleaning.execute(malicious_code, context=context, logger=logger)
 
+    def test_clean_code_raise_not_whitelisted_lib_with_none_security(
+        self,
+        code_cleaning: CodeCleaning,
+        context: PipelineContext,
+        logger: Logger,
+    ):
+        builtins_code = """import scipy
+result = {'type': 'number', 'value': set([1, 2, 3])}"""
+
+        context.config.security = "none"
+        with pytest.raises(BadImportError):
+            code_cleaning.execute(builtins_code, context=context, logger=logger)
+
     def test_clean_code_with_pltshow_in_code(
         self,
         code_cleaning: CodeCleaning,
