Checked for null format in wrong place
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / profile / saml2 / AbstractSAML2ProfileHandler.java
index a5c900b..bfdcfdc 100644 (file)
@@ -410,8 +410,13 @@ public abstract class AbstractSAML2ProfileHandler extends AbstractSAMLProfileHan
 
         ShibbolethSAMLAttributeRequestContext<NameID, AttributeQuery> queryContext;
 
+        if(requestContext.getSamlRequest() instanceof AttributeQuery){
         queryContext = new ShibbolethSAMLAttributeRequestContext<NameID, AttributeQuery>(getMetadataProvider(),
                 requestContext.getRelyingPartyConfiguration(), (AttributeQuery) requestContext.getSamlRequest());
+        }else{
+            queryContext = new ShibbolethSAMLAttributeRequestContext<NameID, AttributeQuery>(getMetadataProvider(),
+                    requestContext.getRelyingPartyConfiguration(), null);
+        }
         queryContext.setAttributeRequester(requestContext.getAssertingPartyId());
         queryContext.setPrincipalName(requestContext.getPrincipalName());
         queryContext.setProfileConfiguration(requestContext.getProfileConfiguration());
@@ -636,7 +641,7 @@ public abstract class AbstractSAML2ProfileHandler extends AbstractSAMLProfileHan
         String nameFormat = null;
         if (requestContext.getSamlRequest() instanceof AuthnRequest) {
             AuthnRequest authnRequest = (AuthnRequest) requestContext.getSamlRequest();
-            if (authnRequest.getNameIDPolicy() != null) {
+            if (authnRequest.getNameIDPolicy() != null && !DatatypeHelper.isEmpty(nameFormat)) {
                 nameFormat = authnRequest.getNameIDPolicy().getFormat();
                 if (assertingPartySupportedFormats.contains(nameFormat)) {
                     nameFormats.add(nameFormat);
@@ -725,6 +730,9 @@ public abstract class AbstractSAML2ProfileHandler extends AbstractSAMLProfileHan
         auditLogEntry.setRequestId(context.getSamlRequest().getID());
         auditLogEntry.setResponseBinding(context.getMessageEncoder().getBindingURI());
         auditLogEntry.setResponseId(context.getSamlResponse().getID());
+        if(context.getPrincipalAttributes() != null){
+            auditLogEntry.getReleasedAttributes().addAll(context.getPrincipalAttributes().keySet());
+        }
         getAduitLog().log(Level.CRITICAL, auditLogEntry);
     }