Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion starlette/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket

if False: # TYPE_CHECKING
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a built-in flag in typing for this :-)

Suggested change
if False: # TYPE_CHECKING
if typing.TYPE_CHECKING:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, makes sense (given that this is py36+), thanks!
Not going to update it for now though, given its staleness already.

import sys

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


class HTTPEndpoint:
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down Expand Up @@ -43,7 +51,7 @@ async def method_not_allowed(self, request: Request) -> Response:

class WebSocketEndpoint:

encoding = None # May be "text", "bytes", or "json".
encoding = None # type: typing.Optional[Literal["text", "bytes", "json"]]

def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "websocket"
Expand Down
2 changes: 1 addition & 1 deletion starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)

def url_path_for(self, name: str, **path_params: str) -> URLPath:
def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(name, **path_params)
Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def websocket_connect(

return session

def __enter__(self) -> requests.Session:
def __enter__(self) -> "TestClient":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to have this fix merged.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we go #1064

loop = asyncio.get_event_loop()
self.send_queue = asyncio.Queue() # type: asyncio.Queue
self.receive_queue = asyncio.Queue() # type: asyncio.Queue
Expand Down
2 changes: 2 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def read_note(request):
note_id = request.path_params["note_id"]
query = notes.select().where(notes.c.id == note_id)
result = await database.fetch_one(query)
assert result
content = {"text": result["text"], "completed": result["completed"]}
return JSONResponse(content)

Expand All @@ -84,6 +85,7 @@ async def read_note_text(request):
note_id = request.path_params["note_id"]
query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id)
result = await database.fetch_one(query)
assert result
return JSONResponse(result[0])


Expand Down
3 changes: 2 additions & 1 deletion tests/test_datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MultiDict,
MutableHeaders,
QueryParams,
UploadFile,
)


Expand Down Expand Up @@ -211,7 +212,7 @@ def test_queryparams():


def test_formdata():
upload = io.BytesIO(b"test")
upload = UploadFile("upload", io.BytesIO(b"test"))
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
assert "a" in form
assert "A" not in form
Expand Down
4 changes: 3 additions & 1 deletion tests/test_formparsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import typing

from starlette.formparsers import UploadFile, _user_safe_decode
from starlette.requests import Request
Expand All @@ -22,6 +23,7 @@ async def app(scope, receive, send):
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
assert isinstance(content, bytes)
output[key] = {
"filename": value.filename,
"content": content.decode(),
Expand All @@ -37,7 +39,7 @@ async def app(scope, receive, send):
async def multi_items_app(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
output: typing.Dict[str, typing.List] = {}
for key, value in data.multi_items():
if key not in output:
output[key] = []
Expand Down
10 changes: 6 additions & 4 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,13 @@ async def subdomain_app(scope, receive, send):
await response(scope, receive, send)


subdomain_app = Router(
subdomain_router = Router(
routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]
)


def test_subdomain_routing():
client = TestClient(subdomain_app, base_url="https://foo.example.org/")
client = TestClient(subdomain_router, base_url="https://foo.example.org/")

response = client.get("/")
assert response.status_code == 200
Expand All @@ -361,7 +361,7 @@ def test_subdomain_routing():

def test_subdomain_reverse_urls():
assert (
subdomain_app.url_path_for(
subdomain_router.url_path_for(
"subdomains", subdomain="foo", path="/homepage"
).make_absolute_url("https://whatever")
== "https://foo.example.org/homepage"
Expand Down Expand Up @@ -403,7 +403,9 @@ def test_url_for_with_root_path():


double_mount_routes = [
Mount("/mount", name="mount", routes=[Mount("/static", ..., name="static")],),
Mount(
"/mount", name="mount", routes=[Mount("/static", Starlette(), name="static")],
),
]


Expand Down
6 changes: 3 additions & 3 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ async def homepage(request):
client = TestClient(app)
response = client.get("/")
assert response.text == "<html>Hello, <a href='http://testserver/'>world</a></html>"
assert response.template.name == "index.html"
assert set(response.context.keys()) == {"request"}
assert response.template.name == "index.html" # type: ignore
assert set(response.context.keys()) == {"request"} # type: ignore


def test_template_response_requires_request(tmpdir):
templates = Jinja2Templates(str(tmpdir))
with pytest.raises(ValueError):
templates.TemplateResponse(None, {})
templates.TemplateResponse("name", {})