Deal with selecting relying party endpoint sooner and populating subject confirmation...
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / profile / saml2 / SSOProfileHandler.java
index b10d28a..9dcea39 100644 (file)
@@ -36,6 +36,7 @@ import org.opensaml.common.binding.security.SAMLSecurityPolicy;
 import org.opensaml.common.xml.SAMLConstants;
 import org.opensaml.saml2.core.SubjectLocality;
 import org.opensaml.saml2.binding.AuthnResponseEndpointSelector;
 import org.opensaml.common.xml.SAMLConstants;
 import org.opensaml.saml2.core.SubjectLocality;
 import org.opensaml.saml2.binding.AuthnResponseEndpointSelector;
+import org.opensaml.saml2.core.AttributeStatement;
 import org.opensaml.saml2.core.AuthnContext;
 import org.opensaml.saml2.core.AuthnContextClassRef;
 import org.opensaml.saml2.core.AuthnContextDeclRef;
 import org.opensaml.saml2.core.AuthnContext;
 import org.opensaml.saml2.core.AuthnContextClassRef;
 import org.opensaml.saml2.core.AuthnContextDeclRef;
@@ -226,15 +227,15 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
                 throw new ProfileException("User failed authentication");
             }
 
                 throw new ProfileException("User failed authentication");
             }
 
+            AuthnStatement authnStatement = buildAuthnStatement(requestContext);
+            AttributeStatement attributeStatement = buildAttributeStatement(requestContext);
+            
             ArrayList<Statement> statements = new ArrayList<Statement>();
             ArrayList<Statement> statements = new ArrayList<Statement>();
-            statements.add(buildAuthnStatement(requestContext));
-            if (requestContext.getProfileConfiguration().includeAttributeStatement()) {
-                statements.add(buildAttributeStatement(requestContext));
-            }
-
-            Subject assertionSubject = buildSubject(requestContext, "urn:oasis:names:tc:SAML:2.0:cm:bearer");
+            statements.add(authnStatement);
+            //TODO optional include this
+            statements.add(attributeStatement);
 
 
-            samlResponse = buildResponse(requestContext, assertionSubject, statements);
+            samlResponse = buildResponse(requestContext, "urn:oasis:names:tc:SAML:2.0:cm:bearer", statements);
         } catch (ProfileException e) {
             samlResponse = buildErrorResponse(requestContext);
         }
         } catch (ProfileException e) {
             samlResponse = buildErrorResponse(requestContext);
         }
@@ -322,6 +323,8 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
                     .getProfileConfiguration(SSOConfiguration.PROFILE_ID));
 
             requestContext.setSamlRequest(authnRequest);
                     .getProfileConfiguration(SSOConfiguration.PROFILE_ID));
 
             requestContext.setSamlRequest(authnRequest);
+            
+            selectEndpoint(requestContext);
 
             return requestContext;
         } catch (UnmarshallingException e) {
 
             return requestContext;
         } catch (UnmarshallingException e) {
@@ -422,6 +425,22 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
 
         return subjectLocality;
     }
 
         return subjectLocality;
     }
+    
+    /**
+     * Selects the appropriate endpoint for the relying party and stores it in the request context.
+     * 
+     * @param requestContext current request context
+     */
+    protected void selectEndpoint(SSORequestContext requestContext){
+        AuthnResponseEndpointSelector endpointSelector = new AuthnResponseEndpointSelector();
+        endpointSelector.setEndpointType(AssertionConsumerService.DEFAULT_ELEMENT_NAME);
+        endpointSelector.setMetadataProvider(getMetadataProvider());
+        endpointSelector.setRelyingParty(requestContext.getRelyingPartyMetadata());
+        endpointSelector.setRelyingPartyRole(requestContext.getRelyingPartyRoleMetadata());
+        endpointSelector.setSamlRequest(requestContext.getSamlRequest());
+        endpointSelector.getSupportedIssuerBindings().addAll(supportedOutgoingBindings);
+        requestContext.setRelyingPartyEndpoint(endpointSelector.selectEndpoint());
+    }
 
     /**
      * Encodes the request's SAML response and writes it to the servlet response.
 
     /**
      * Encodes the request's SAML response and writes it to the servlet response.
@@ -435,15 +454,8 @@ public class SSOProfileHandler extends AbstractSAML2ProfileHandler {
             log.debug("Encoding response to SAML request " + requestContext.getSamlRequest().getID()
                     + " from relying party " + requestContext.getRelyingPartyId());
         }
             log.debug("Encoding response to SAML request " + requestContext.getSamlRequest().getID()
                     + " from relying party " + requestContext.getRelyingPartyId());
         }
-        AuthnResponseEndpointSelector endpointSelector = new AuthnResponseEndpointSelector();
-        endpointSelector.setEndpointType(AssertionConsumerService.DEFAULT_ELEMENT_NAME);
-        endpointSelector.setMetadataProvider(getMetadataProvider());
-        endpointSelector.setRelyingParty(requestContext.getRelyingPartyMetadata());
-        endpointSelector.setRelyingPartyRole(requestContext.getRelyingPartyRoleMetadata());
-        endpointSelector.setSamlRequest(requestContext.getSamlRequest());
-        endpointSelector.getSupportedIssuerBindings().addAll(supportedOutgoingBindings);
-        Endpoint relyingPartyEndpoint = endpointSelector.selectEndpoint();
 
 
+        Endpoint relyingPartyEndpoint = requestContext.getRelyingPartyEndpoint();
         MessageEncoder<ServletResponse> encoder = getMessageEncoderFactory().getMessageEncoder(
                 relyingPartyEndpoint.getBinding());
         if (encoder == null) {
         MessageEncoder<ServletResponse> encoder = getMessageEncoderFactory().getMessageEncoder(
                 relyingPartyEndpoint.getBinding());
         if (encoder == null) {