diff --git a/.env.example b/.env.example index 6e50e4344..20f420192 100644 --- a/.env.example +++ b/.env.example @@ -72,3 +72,11 @@ SPOOLMAN_PORT=7912 # Default if not set: 1000 #PUID=1000 #PGID=1000 + +# Allows CORS ORIGIN. +# Use the https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Origin semantics +# separated by commas +# for example to allow request from source1.domain.com on port p1 and source2.domain.com on port p2 +# SPOOLMAN_CORS_ORIGIN=source1.domain.com:p1, source2.domain.com:p2 +# to allow from any +# SPOOLMAN_CORS_ORIGIN=* diff --git a/spoolman/env.py b/spoolman/env.py index cbb4a40c5..dc90bdf8c 100644 --- a/spoolman/env.py +++ b/spoolman/env.py @@ -210,6 +210,35 @@ def is_debug_mode() -> bool: return True raise ValueError(f"Failed to parse SPOOLMAN_DEBUG_MODE variable: Unknown debug mode '{debug_mode}'.") +def is_cors_defined() -> bool: + """Get whether CORS is enabled from environment variables. + + Returns False if no environment variable was set for CORS. + Returns True otherwise + + Returns: + bool: Whether CORS is enabled. + + """ + cors = os.getenv("SPOOLMAN_CORS_ORIGIN", "FALSE").upper() + if cors in {"FALSE", "0"}: + return False + else: + return True + +def get_cors_origin() -> Optional[list[str]]: + """Get the CORS origin from environment variables. + + Returns None if no environment variable was set for the origin. + + Returns: + Optional[str]: The origin. + + """ + cors = os.getenv("SPOOLMAN_CORS_ORIGIN") + if cors is None: + return None + return cors.split(",") def is_automatic_backup_enabled() -> bool: """Get whether automatic backup is enabled from environment variables. diff --git a/spoolman/main.py b/spoolman/main.py index 1b75b2113..6e716078d 100644 --- a/spoolman/main.py +++ b/spoolman/main.py @@ -104,12 +104,16 @@ def get_configjs() -> Response: app.mount(base_path, app=SinglePageApplication(directory="client/dist", base_path=env.get_base_path())) # Allow all origins if in debug mode -if env.is_debug_mode(): - logger.warning("Running in debug mode, allowing all origins.") +if env.is_debug_mode() or env.is_cors_defined(): + if(env.is_cors_defined()): + origins = env.get_cors_origin() + else: + origins = ["*"] + logger.warning("Running in debug mode, allowing all origins.") app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"],