Skip to content

feat(security): add security config to disable it #1498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 2, 2025
5 changes: 4 additions & 1 deletion pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def chat(self, query: str, output_type: Optional[str] = None):

self.assign_prompt_id()

if self.check_malicious_keywords_in_query(query):
if self.config.security in [
"standard",
"advanced",
] and self.check_malicious_keywords_in_query(query):
raise MaliciousQueryError(
"The query contains references to io or os modules or b64decode method which can be used to execute or access system resources in unsafe ways."
)
Expand Down
1 change: 0 additions & 1 deletion pandasai/ee/vectorstores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import chromadb
from chromadb import config
from chromadb.utils import embedding_functions

from pandasai.helpers.logger import Logger
from pandasai.helpers.path import find_project_root
from pandasai.vectorstores.vectorstore import VectorStore
Expand Down
45 changes: 32 additions & 13 deletions pandasai/helpers/optional.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_version(module: types.ModuleType) -> str:
return version


def get_environment(additional_deps: List[dict]) -> dict:
def get_environment(additional_deps: List[dict], secure: bool = True) -> dict:
"""
Returns the environment for the code to be executed.

Expand All @@ -74,22 +74,41 @@ def get_environment(additional_deps: List[dict]) -> dict:
},
}

env["pd"] = RestrictedPandas()
env["plt"] = RestrictedMatplotlib()
env["np"] = RestrictedNumpy()
if secure:
env["pd"] = RestrictedPandas()
env["plt"] = RestrictedMatplotlib()
env["np"] = RestrictedNumpy()

for lib in additional_deps:
if lib["name"] == "seaborn":
env["sns"] = RestrictedSeaborn()
for lib in additional_deps:
if lib["name"] == "seaborn":
env["sns"] = RestrictedSeaborn()

if lib["name"] == "datetime":
env["datetime"] = RestrictedDatetime()
if lib["name"] == "datetime":
env["datetime"] = RestrictedDatetime()

if lib["name"] == "json":
env["json"] = RestrictedJson()
if lib["name"] == "json":
env["json"] = RestrictedJson()

if lib["name"] == "base64":
env["base64"] = RestrictedBase64()
if lib["name"] == "base64":
env["base64"] = RestrictedBase64()

else:
env["pd"] = import_dependency("pandas")
env["plt"] = import_dependency("matplotlib.pyplot")
env["np"] = import_dependency("numpy")

for lib in additional_deps:
if lib["name"] == "seaborn":
env["sns"] = import_dependency("seaborn")

if lib["name"] == "datetime":
env["datetime"] = import_dependency("datetime")

if lib["name"] == "json":
env["json"] = import_dependency("json")

if lib["name"] == "base64":
env["base64"] = import_dependency("base64")

return env

Expand Down
11 changes: 9 additions & 2 deletions pandasai/pipelines/chat/code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,14 @@ def _replace_plot_png(self, code):
return re.sub(r"""(['"])([^'"]*\.png)\1""", r"\1temp_chart.png\1", code)

def get_code_to_run(self, code: str, context: CodeExecutionContext) -> Any:
if self._is_malicious_code(code):
if self._config.security in [
"standard",
"advanced",
] and self._is_malicious_code(code):
raise MaliciousQueryError(
"Code shouldn't use 'os', 'io' or 'chr', 'b64decode' functions as this could lead to malicious code execution."
)

code = self._replace_plot_png(code)
self._current_code_executed = code

Expand Down Expand Up @@ -471,7 +475,10 @@ def _extract_fix_dataframe_redeclarations(
if target_names and self._check_is_df_declaration(node):
# Construct dataframe from node
code = "\n".join(code_lines)
env = get_environment(self._additional_dependencies)
env = get_environment(
self._additional_dependencies,
secure=self._config.security in ["standard", "advanced"],
)
env["dfs"] = copy.deepcopy(self._get_originals(self._dfs))
if context.skills_manager.used_skills:
for skill_func_name in context.skills_manager.used_skills:
Expand Down
5 changes: 4 additions & 1 deletion pandasai/pipelines/chat/code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any:
# List the required dfs, so we can avoid running the connectors
# if the code does not need them
dfs = self._required_dfs(code)
environment: dict = get_environment(self._additional_dependencies)
environment: dict = get_environment(
self._additional_dependencies,
secure=self._config.security in ["standard", "advanced"],
)
environment["dfs"] = self._get_originals(dfs)
if len(environment["dfs"]) == 1:
environment["df"] = environment["dfs"][0]
Expand Down
2 changes: 1 addition & 1 deletion pandasai/safe_libs/base_restricted_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ def wrapper(*args, **kwargs):
# Check for any suspicious arguments that might be used for importing
for arg in args + tuple(kwargs.values()):
if isinstance(arg, str) and any(
module in arg.lower()
module == arg.lower()
for module in ["io", "os", "subprocess", "sys", "importlib"]
):
raise SecurityError(
Expand Down
3 changes: 2 additions & 1 deletion pandasai/schemas/df_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, TypedDict
from typing import Any, List, Literal, Optional, TypedDict

from pandasai.constants import DEFAULT_CHART_DIRECTORY
from pandasai.helpers.dataframe_serializer import DataframeSerializerType
Expand Down Expand Up @@ -30,6 +30,7 @@ class Config(BaseModel):
log_server: LogServerConfig = None
direct_sql: bool = False
dataframe_serializer: DataframeSerializerType = DataframeSerializerType.CSV
security: Literal["standard", "none", "advanced"] = "standard"

class Config:
arbitrary_types_allowed = True
Expand Down
15 changes: 15 additions & 0 deletions tests/unit_tests/helpers/test_optional_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import pytest

from pandasai.helpers.optional import VERSIONS, get_environment, import_dependency
from pandasai.safe_libs.restricted_matplotlib import RestrictedMatplotlib
from pandasai.safe_libs.restricted_numpy import RestrictedNumpy
from pandasai.safe_libs.restricted_pandas import RestrictedPandas


def test_import_optional():
Expand Down Expand Up @@ -91,3 +94,15 @@ def test_env_for_necessary_deps():
assert "pd" in env
assert "plt" in env
assert "np" in env


def test_env_for_security():
    """Verify that the `secure` flag decides whether the core aliases are restricted.

    With `secure=True` the `pd`, `plt` and `np` entries must be instances of
    the restricted wrapper classes; with `secure=False` the same aliases must
    be present but must not be instances of those wrappers.
    """
    # Alias -> restricted wrapper class expected when security is enabled.
    restricted_by_alias = {
        "pd": RestrictedPandas,
        "plt": RestrictedMatplotlib,
        "np": RestrictedNumpy,
    }

    secured_env = get_environment([], secure=True)
    for alias, wrapper_cls in restricted_by_alias.items():
        assert alias in secured_env and isinstance(secured_env[alias], wrapper_cls)

    unsecured_env = get_environment([], secure=False)
    for alias, wrapper_cls in restricted_by_alias.items():
        assert alias in unsecured_env and not isinstance(
            unsecured_env[alias], wrapper_cls
        )
30 changes: 25 additions & 5 deletions tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,13 +577,16 @@ def test_clean_code_using_multi_incorrect_sql_table(
assert str(excinfo.value) == ("Query uses unauthorized table: table1.")

@patch("pandasai.connectors.pandas.PandasConnector.head")
def test_fix_dataframe_redeclarations(self, mock_head, context: PipelineContext):
def test_fix_dataframe_redeclarations(
self, mock_head, context: PipelineContext, config: dict
):
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
mock_head.return_value = df
pandas_connector = PandasConnector({"original_df": df})

code_cleaning = CodeCleaning()
code_cleaning._dfs = [pandas_connector]
code_cleaning._config = Config(**config)
context.dfs = [pandas_connector]

python_code = """
Expand All @@ -601,14 +604,15 @@ def test_fix_dataframe_redeclarations(self, mock_head, context: PipelineContext)

@patch("pandasai.connectors.pandas.PandasConnector.head")
def test_fix_dataframe_multiline_redeclarations(
self, mock_head, context: PipelineContext
self, mock_head, context: PipelineContext, config: dict
):
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
mock_head.return_value = df
pandas_connector = PandasConnector({"original_df": df})

code_cleaning = CodeCleaning()
code_cleaning._dfs = [pandas_connector]
code_cleaning._config = Config(**config)
context.dfs = [pandas_connector]

python_code = """
Expand Down Expand Up @@ -660,14 +664,15 @@ def test_fix_dataframe_no_redeclarations(self, mock_head, context: PipelineConte

@patch("pandasai.connectors.pandas.PandasConnector.head")
def test_fix_dataframe_redeclarations_with_subscript(
self, mock_head, context: PipelineContext
self, mock_head, context: PipelineContext, config: dict
):
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
mock_head.return_value = df
pandas_connector = PandasConnector({"original_df": df})

code_cleaning = CodeCleaning()
code_cleaning._dfs = [pandas_connector]
code_cleaning._config = Config(**config)
context.dfs = [pandas_connector]

python_code = """
Expand All @@ -685,7 +690,7 @@ def test_fix_dataframe_redeclarations_with_subscript(

@patch("pandasai.connectors.pandas.PandasConnector.head")
def test_fix_dataframe_redeclarations_with_subscript_and_data_variable(
self, mock_head, context: PipelineContext
self, mock_head, context: PipelineContext, config: dict
):
data = {
"country": ["China", "United States", "Japan", "Germany", "United Kingdom"],
Expand All @@ -697,6 +702,7 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable(

code_cleaning = CodeCleaning()
code_cleaning._dfs = [pandas_connector]
code_cleaning._config = Config(**config)
context.dfs = [pandas_connector]

python_code = """
Expand All @@ -719,7 +725,7 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable(

@patch("pandasai.connectors.pandas.PandasConnector.head")
def test_fix_dataframe_redeclarations_and_data_variable(
self, mock_head, context: PipelineContext
self, mock_head, context: PipelineContext, config: Config
):
data = {
"country": ["China", "United States", "Japan", "Germany", "United Kingdom"],
Expand All @@ -731,6 +737,7 @@ def test_fix_dataframe_redeclarations_and_data_variable(

code_cleaning = CodeCleaning()
code_cleaning._dfs = [pandas_connector]
code_cleaning._config = Config(**config)
context.dfs = [pandas_connector]

python_code = """
Expand Down Expand Up @@ -928,3 +935,16 @@ def test_clean_code_raise_import_with_restricted_using_import_statement(
"""
with pytest.raises(MaliciousQueryError):
code_cleaning.execute(malicious_code, context=context, logger=logger)

def test_clean_code_raise_not_whitelisted_lib_with_none_security(
    self,
    code_cleaning: CodeCleaning,
    context: PipelineContext,
    logger: Logger,
):
    """Check that a bad import is still rejected when security is "none".

    The snippet imports `scipy` (expected not to be whitelisted, per the
    test name) and the cleaning step must raise `BadImportError` even with
    `config.security` set to "none".
    """
    # Switch the security checks off before running the cleaning step.
    context.config.security = "none"

    unlisted_import_code = (
        "import scipy\n"
        "result = {'type': 'number', 'value': set([1, 2, 3])}"
    )

    with pytest.raises(BadImportError):
        code_cleaning.execute(unlisted_import_code, context=context, logger=logger)
Loading