Fixed ClassCastException when using Default Relying Party.
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / common / ServiceProviderMapper.java
1 /*
2  * The Shibboleth License, Version 1. Copyright (c) 2002 University Corporation
3  * for Advanced Internet Development, Inc. All rights reserved
4  * 
5  * 
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * 
9  * Redistributions of source code must retain the above copyright notice, this
10  * list of conditions and the following disclaimer.
11  * 
12  * Redistributions in binary form must reproduce the above copyright notice,
13  * this list of conditions and the following disclaimer in the documentation
14  * and/or other materials provided with the distribution, if any, must include
15  * the following acknowledgment: "This product includes software developed by
16  * the University Corporation for Advanced Internet Development
17  * <http://www.ucaid.edu> Internet2 Project. Alternately, this acknowledegement
18  * may appear in the software itself, if and wherever such third-party
19  * acknowledgments normally appear.
20  * 
21  * Neither the name of Shibboleth nor the names of its contributors, nor
22  * Internet2, nor the University Corporation for Advanced Internet Development,
23  * Inc., nor UCAID may be used to endorse or promote products derived from this
24  * software without specific prior written permission. For written permission,
25  * please contact shibboleth@shibboleth.org
26  * 
27  * Products derived from this software may not be called Shibboleth, Internet2,
28  * UCAID, or the University Corporation for Advanced Internet Development, nor
29  * may Shibboleth appear in their name, without prior written permission of the
30  * University Corporation for Advanced Internet Development.
31  * 
32  * 
33  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
34  * AND WITH ALL FAULTS. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
35  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
36  * PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED AND THE ENTIRE RISK
37  * OF SATISFACTORY QUALITY, PERFORMANCE, ACCURACY, AND EFFORT IS WITH LICENSEE.
38  * IN NO EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE UNIVERSITY
39  * CORPORATION FOR ADVANCED INTERNET DEVELOPMENT, INC. BE LIABLE FOR ANY
40  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
41  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
42  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
43  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
44  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
45  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
46  */
47 package edu.internet2.middleware.shibboleth.common;
48
49 import java.net.URI;
50 import java.net.URL;
51 import java.util.HashMap;
52 import java.util.Map;
53
54 import org.apache.log4j.Logger;
55 import org.w3c.dom.Element;
56
57 import edu.internet2.middleware.shibboleth.aa.AARelyingParty;
58 import edu.internet2.middleware.shibboleth.hs.HSRelyingParty;
59
60 /**
61  * @author Walter Hoehn
62  *  
63  */
64 public abstract class ServiceProviderMapper {
65
66         private static Logger log = Logger.getLogger(ServiceProviderMapper.class.getName());
67         protected Map relyingParties = new HashMap();
68
69         protected abstract ShibbolethOriginConfig getOriginConfig();
70
71         protected void verifyDefaultParty(ShibbolethOriginConfig configuration) throws ServiceProviderMapperException {
72                 //Verify we have a proper default party
73                 String defaultParty = configuration.getDefaultRelyingPartyName();
74                 if (defaultParty == null || defaultParty.equals("")) {
75                         if (relyingParties.size() != 1) {
76                                 log.error(
77                                         "Default Relying Party not specified.  Add a (defaultRelyingParty) attribute to <ShibbolethOriginConfig>.");
78                                 throw new ServiceProviderMapperException("Required configuration not specified.");
79                         } else {
80                                 log.debug("Only one Relying Party loaded.  Using this as the default.");
81                         }
82                 }
83                 log.debug("Default Relying Party set to: (" + defaultParty + ").");
84                 if (!relyingParties.containsKey(defaultParty)) {
85                         log.error("Default Relying Party refers to a Relying Party that has not been loaded.");
86                         throw new ServiceProviderMapperException("Invalid configuration (Default Relying Party).");
87                 }
88         }
89
90         protected RelyingParty getRelyingPartyImpl(String providerIdFromTarget) {
91
92                 //Look for a configuration for the specific relying party
93                 if (relyingParties.containsKey(providerIdFromTarget)) {
94                         log.info("Found Relying Party for (" + providerIdFromTarget + ").");
95                         return (RelyingParty) relyingParties.get(providerIdFromTarget);
96                 }
97
98                 //Next, check to see if the relying party is in any groups
99                 RelyingParty groupParty = findRelyingPartyByGroup(providerIdFromTarget);
100                 if (groupParty != null) {
101                         log.info("Provider is a member of Relying Party (" + groupParty.getName() + ").");
102                         return new RelyingPartyGroupWrapper(groupParty, providerIdFromTarget);
103                 }
104
105                 //OK, just send the default
106                 log.info(
107                         "Could not locate Relying Party configuration for ("
108                                 + providerIdFromTarget
109                                 + ").  Using default Relying Party.");
110                 return new UnknownProviderWrapper(getDefaultRelyingPatry());
111         }
112
113         private RelyingParty findRelyingPartyByGroup(String providerIdFromTarget) {
114
115                 // TODO This is totally a stub and needs to be based on target metadata
116                 // lookup
117                 if (providerIdFromTarget.startsWith("urn:mace:inqueue:")) {
118                         if (relyingParties.containsKey("urn:mace:inqueue")) {
119                                 return (RelyingParty) relyingParties.get("urn:mace:inqueue");
120                         }
121                 }
122                 return null;
123         }
124
125         protected RelyingParty getDefaultRelyingPatry() {
126
127                 //If there is no explicit default, pick the single configured Relying
128                 // Party
129                 String defaultParty = getOriginConfig().getDefaultRelyingPartyName();
130                 if (defaultParty == null || defaultParty.equals("")) {
131                         return (RelyingParty) relyingParties.values().iterator().next();
132                 }
133
134                 //If we do have a default specified, use it...
135                 return (RelyingParty) relyingParties.get(defaultParty);
136         }
137
138         protected abstract class BaseRelyingPartyImpl implements RelyingParty {
139
140                 protected RelyingPartyIdentityProvider identityProvider;
141                 protected String name;
142                 protected String overridenOriginProviderId;
143
144                 /**
145                  * Shared construction
146                  */
147                 public BaseRelyingPartyImpl(Element partyConfig) throws ServiceProviderMapperException {
148
149                         //Get party name
150                         name = ((Element) partyConfig).getAttribute("name");
151                         if (name == null || name.equals("")) {
152                                 log.error("Relying Party name not set.  Add a (name) attribute to <RelyingParty>.");
153                                 throw new ServiceProviderMapperException("Required configuration not specified.");
154                         }
155                         log.debug("Loading Relying Party: (" + name + ").");
156
157                         //Process overrides for global data
158                         String attribute = ((Element) partyConfig).getAttribute("providerId");
159                         if (attribute != null && !attribute.equals("")) {
160                                 log.debug("Overriding providerId for Relying Pary (" + name + ") with (" + attribute + ").");
161                                 overridenOriginProviderId = attribute;
162                         }
163
164                 }
165
166                 public String getProviderId() {
167                         return name;
168                 }
169
170                 public String getName() {
171                         return name;
172                 }
173
174                 public IdentityProvider getIdentityProvider() {
175                         return identityProvider;
176                 }
177
178                 protected class RelyingPartyIdentityProvider implements IdentityProvider {
179
180                         private String providerId;
181                         private Credential responseSigningCredential;
182
183                         public RelyingPartyIdentityProvider(String providerId, Credential responseSigningCred) {
184                                 this.providerId = providerId;
185                                 this.responseSigningCredential = responseSigningCred;
186                         }
187
188                         public String getProviderId() {
189                                 return providerId;
190                         }
191
192                         public Credential getResponseSigningCredential() {
193                                 return responseSigningCredential;
194                         }
195
196                         public Credential getAssertionSigningCredential() {
197                                 return null;
198                         }
199
200                 }
201         }
202
203         class RelyingPartyGroupWrapper implements RelyingParty, HSRelyingParty, AARelyingParty {
204
205                 private RelyingParty wrapped;
206                 private String providerId;
207
208                 RelyingPartyGroupWrapper(RelyingParty wrapped, String providerId) {
209                         this.wrapped = wrapped;
210                         this.providerId = providerId;
211                 }
212
213                 public String getName() {
214                         return wrapped.getName();
215                 }
216
217                 public boolean isLegacyProvider() {
218                         return false;
219                 }
220
221                 public IdentityProvider getIdentityProvider() {
222                         return wrapped.getIdentityProvider();
223                 }
224
225                 public String getProviderId() {
226                         return providerId;
227                 }
228                 public String getHSNameFormatId() {
229                         if (!(wrapped instanceof HSRelyingParty)) {
230                                 return null;
231                         }
232                         return ((HSRelyingParty) wrapped).getHSNameFormatId();
233                 }
234
235                 public URL getAAUrl() {
236                         if (!(wrapped instanceof HSRelyingParty)) {
237                                 return null;
238                         }
239                         return ((HSRelyingParty) wrapped).getAAUrl();
240                 }
241
242                 public URI getDefaultAuthMethod() {
243                         if (!(wrapped instanceof HSRelyingParty)) {
244                                 return null;
245                         }
246                         return ((HSRelyingParty) wrapped).getDefaultAuthMethod();
247                 }
248
249                 public boolean passThruErrors() {
250                         if (!(wrapped instanceof AARelyingParty)) {
251                                 return false;
252                         }
253                         return ((AARelyingParty) wrapped).passThruErrors();
254                 }
255         }
256
257         protected class UnknownProviderWrapper implements RelyingParty, HSRelyingParty, AARelyingParty {
258                 protected RelyingParty wrapped;
259
260                 protected UnknownProviderWrapper(RelyingParty wrapped) {
261                         this.wrapped = wrapped;
262                 }
263
264                 public String getName() {
265                         return wrapped.getName();
266                 }
267
268                 public IdentityProvider getIdentityProvider() {
269                         return wrapped.getIdentityProvider();
270                 }
271
272                 public String getProviderId() {
273                         return null;
274                 }
275
276                 public String getHSNameFormatId() {
277                         if (!(wrapped instanceof HSRelyingParty)) {
278                                 return null;
279                         }
280                         return ((HSRelyingParty) wrapped).getHSNameFormatId();
281                 }
282
283                 public boolean isLegacyProvider() {
284                         if (!(wrapped instanceof HSRelyingParty)) {
285                                 return false;
286                         }
287                         return ((HSRelyingParty) wrapped).isLegacyProvider();
288                 }
289
290                 public URL getAAUrl() {
291                         if (!(wrapped instanceof HSRelyingParty)) {
292                                 return null;
293                         }
294                         return ((HSRelyingParty) wrapped).getAAUrl();
295                 }
296
297                 public URI getDefaultAuthMethod() {
298                         if (!(wrapped instanceof HSRelyingParty)) {
299                                 return null;
300                         }
301                         return ((HSRelyingParty) wrapped).getDefaultAuthMethod();
302                 }
303
304                 public boolean passThruErrors() {
305                         if (!(wrapped instanceof AARelyingParty)) {
306                                 return false;
307                         }
308                         return ((AARelyingParty) wrapped).passThruErrors();
309                 }
310         }
311
312 }