Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
logger = logging.getLogger(__name__)


def _normalize_resource_url(resource: str) -> str:
parsed = urlparse(resource)
if parsed.path == "/" and not parsed.params and not parsed.query and not parsed.fragment:
return f"{parsed.scheme}://{parsed.netloc}"
return resource


class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""

Expand Down Expand Up @@ -151,7 +158,7 @@ def get_resource_url(self) -> str:

# If PRM provides a resource that's a valid parent, use it
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
prm_resource = str(self.protected_resource_metadata.resource)
prm_resource = _normalize_resource_url(str(self.protected_resource_metadata.resource))
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
resource = prm_resource

Expand Down Expand Up @@ -442,10 +449,6 @@ async def _refresh_token(self) -> httpx.Request:
"client_id": self.context.client_info.client_id,
}

# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707

# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
Expand Down
37 changes: 33 additions & 4 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ class TestProtectedResourceMetadata:

@pytest.mark.anyio
async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider):
"""Test resource parameter is included for protocol version >= 2025-06-18."""
"""Test resource parameter is included for initial token requests on newer protocol versions."""
# Set protocol version to 2025-06-18
oauth_provider.context.protocol_version = "2025-06-18"
oauth_provider.context.client_info = OAuthClientInformationFull(
Expand All @@ -762,15 +762,16 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_
expected_resource = quote(oauth_provider.context.get_resource_url(), safe="")
assert f"resource={expected_resource}" in content

# Test in refresh token
# Refresh tokens should not resend the resource parameter. Some providers
# reject RFC 8707 resource values on refresh_token grants.
oauth_provider.context.current_tokens = OAuthToken(
access_token="test_access",
token_type="Bearer",
refresh_token="test_refresh",
)
refresh_request = await oauth_provider._refresh_token()
refresh_content = refresh_request.content.decode()
assert "resource=" in refresh_content
assert "resource=" not in refresh_content

@pytest.mark.anyio
async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider):
Expand Down Expand Up @@ -800,7 +801,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro

@pytest.mark.anyio
async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider):
"""Test resource parameter is always included when protected resource metadata exists."""
"""Test resource parameter is included in initial token requests when PRM exists."""
# Set old protocol version but with protected resource metadata
oauth_provider.context.protocol_version = "2025-03-26"
oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata(
Expand All @@ -818,6 +819,15 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
content = request.content.decode()
assert "resource=" in content

oauth_provider.context.current_tokens = OAuthToken(
access_token="test_access",
token_type="Bearer",
refresh_token="test_refresh",
)
refresh_request = await oauth_provider._refresh_token()
refresh_content = refresh_request.content.decode()
assert "resource=" not in refresh_content


@pytest.mark.anyio
async def test_validate_resource_rejects_mismatched_resource(
Expand Down Expand Up @@ -949,6 +959,25 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches(
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")


@pytest.mark.anyio
async def test_get_resource_url_removes_root_prm_trailing_slash(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""Bare-domain PRM resources should not pick up Pydantic's root slash."""
provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)

assert provider.context.get_resource_url() == snapshot("https://api.example.com")


class TestRegistrationResponse:
"""Test client registration response handling."""

Expand Down
Loading