diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py index a5f4b37b86..1e2f434867 100644 --- a/synapse/config/oidc2.py +++ b/synapse/config/oidc2.py @@ -101,13 +101,26 @@ class OIDCProviderModel(BaseModel): scopes: Tuple[StrictStr, ...] = ("openid",) # the oauth2 authorization endpoint. Required if discovery is disabled. - # TODO: required if discovery is disabled authorization_endpoint: Optional[StrictStr] # the oauth2 token endpoint. Required if discovery is disabled. - # TODO: required if discovery is disabled token_endpoint: Optional[StrictStr] + # Normally, validators aren't run when fields don't have a value provided. + # Using validate=True ensures we run the validator even in that situation. + @validator("authorization_endpoint", "token_endpoint", always=True) + def endpoints_required_if_discovery_disabled( + cls: Type["OIDCProviderModel"], + endpoint_url: Optional[str], + values: Mapping[str, Any], + field: ModelField, + ) -> Optional[str]: + # `if "discover" in values means: don't run our checks if "discover" didn't + # pass validation. (NB: validation order is the field definition order) + if "discover" in values and not values["discover"] and endpoint_url is None: + raise ValueError(f"{field.name} is required if discovery is disabled") + return endpoint_url + # the OIDC userinfo endpoint. Required if discovery is disabled and the # "openid" scope is not requested. # TODO: required if discovery is disabled and the openid scope isn't requested diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py index 98e0edd37b..f12e7a5b2b 100644 --- a/tests/config/test_oidc2.py +++ b/tests/config/test_oidc2.py @@ -4,6 +4,7 @@ from unittest import TestCase import yaml from pydantic import ValidationError +from parameterized import parameterized from synapse.config.oidc2 import ( OIDCProviderModel, @@ -29,6 +30,7 @@ client_secret_jwt_key: client_auth_method: "client_secret_post" scopes: ["name", "email", "openid"] authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post +token_endpoint: https://appleid.apple.com/dummy_url_here user_mapping_provider: config: email_template: "{{ user.email }}" @@ -253,3 +255,35 @@ class PydanticOIDCTestCase(TestCase): del self.config["scopes"] model = OIDCProviderModel.parse_obj(self.config) self.assertEqual(model.scopes, ("openid",)) + + @parameterized.expand(["authorization_endpoint", "token_endpoint"]) + def test_endpoints_required_when_discovery_disabled(self, key: str) -> None: + # Test that this field is required if discovery is disabled + self.config["discover"] = False + with self.assertRaises(ValidationError): + self.config[key] = None + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + del self.config[key] + OIDCProviderModel.parse_obj(self.config) + # We don't validate that the endpoint is a sensible URL; anything str will do + self.config[key] = "blahblah" + OIDCProviderModel.parse_obj(self.config) + + def check_all_cases_pass(): + self.config[key] = None + OIDCProviderModel.parse_obj(self.config) + + del self.config[key] + OIDCProviderModel.parse_obj(self.config) + + self.config[key] = "blahblah" + OIDCProviderModel.parse_obj(self.config) + + # With discovery enabled, all three cases are accepted. + self.config["discover"] = True + check_all_cases_pass() + + # If not specified, discovery is also on by default. + del self.config["discover"] + check_all_cases_pass()