Set communication profile before decoding
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / profile / saml1 / ShibbolethSSOProfileHandler.java
index a9c2e07..da1ab1f 100644 (file)
@@ -72,6 +72,9 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
     /** Builder of SubjectLocality objects. */
     private SAMLObjectBuilder<SubjectLocality> subjectLocalityBuilder;
 
+    /** Builder of Endpoint objects. */
+    private SAMLObjectBuilder<Endpoint> endpointBuilder;
+
     /** URL of the authentication manager servlet. */
     private String authenticationManagerPath;
 
@@ -94,6 +97,9 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
 
         subjectLocalityBuilder = (SAMLObjectBuilder<SubjectLocality>) getBuilderFactory().getBuilder(
                 SubjectLocality.DEFAULT_ELEMENT_NAME);
+
+        endpointBuilder = (SAMLObjectBuilder<Endpoint>) getBuilderFactory().getBuilder(
+                AssertionConsumerService.DEFAULT_ELEMENT_NAME);
     }
 
     /** {@inheritDoc} */
@@ -177,6 +183,8 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         HttpServletRequest httpRequest = ((HttpServletRequestAdapter) inTransport).getWrappedRequest();
 
         ShibbolethSSORequestContext requestContext = new ShibbolethSSORequestContext();
+        requestContext.setCommunicationProfileId(getProfileId());
+        
         requestContext.setMetadataProvider(getMetadataProvider());
         requestContext.setSecurityPolicyResolver(getSecurityPolicyResolver());
 
@@ -206,7 +214,6 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         loginContext.setSpTarget(requestContext.getRelayState());
         loginContext.setAuthenticationEngineURL(authenticationManagerPath);
         loginContext.setProfileHandlerURL(HttpHelper.getRequestUriWithoutContext(httpRequest));
-
         requestContext.setLoginContext(loginContext);
 
         return requestContext;
@@ -277,6 +284,7 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
     protected ShibbolethSSORequestContext buildRequestContext(ShibbolethSSOLoginContext loginContext,
             HTTPInTransport in, HTTPOutTransport out) throws ProfileException {
         ShibbolethSSORequestContext requestContext = new ShibbolethSSORequestContext();
+        requestContext.setCommunicationProfileId(getProfileId());
 
         requestContext.setMessageDecoder(getMessageDecoders().get(getInboundBinding()));
 
@@ -292,6 +300,7 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         requestContext.setMetadataProvider(getMetadataProvider());
 
         String relyingPartyId = loginContext.getRelyingPartyId();
+        requestContext.setPeerEntityId(relyingPartyId);
         requestContext.setInboundMessageIssuer(relyingPartyId);
 
         populateRequestContext(requestContext);
@@ -325,8 +334,7 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
     }
 
     /** {@inheritDoc} */
-    protected void populateSAMLMessageInformation(BaseSAMLProfileRequestContext requestContext) 
-        throws ProfileException {
+    protected void populateSAMLMessageInformation(BaseSAMLProfileRequestContext requestContext) throws ProfileException {
         // nothing to do here
     }
 
@@ -349,7 +357,16 @@ public class ShibbolethSSOProfileHandler extends AbstractSAML1ProfileHandler {
         endpointSelector.setSamlRequest(requestContext.getInboundSAMLMessage());
         endpointSelector.getSupportedIssuerBindings().addAll(getSupportedOutboundBindings());
 
-        return endpointSelector.selectEndpoint();
+        Endpoint endpoint = endpointSelector.selectEndpoint();
+        if (endpoint == null && loginContext.getSpAssertionConsumerService() != null) {
+            endpoint = endpointBuilder.buildObject();
+            endpoint.setLocation(loginContext.getSpAssertionConsumerService());
+            endpoint.setBinding(getSupportedOutboundBindings().get(0));
+            log.warn("No endpoint available for relying party {}. Generating endpoint with ACS url {} and binding {}",
+                    new Object[] { requestContext.getPeerEntityId(), endpoint.getLocation(), endpoint.getBinding() });
+        }
+
+        return endpoint;
     }
 
     /**