cidr-enabled ip addr authn handler
[java-idp.git] / src / edu / internet2 / middleware / shibboleth / idp / authn / impl / IPAddressHandler.java
1 /*
2  * Copyright [2006] [University Corporation for Advanced Internet Development, Inc.]
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package edu.internet2.middleware.shibboleth.idp.authn.impl;
18
19 import java.net.InetAddress;
20 import java.net.UnknownHostException;
21 import java.util.List;
22 import java.util.BitSet;
23 import java.util.concurrent.CopyOnWriteArrayList;
24
25 import javax.servlet.RequestDispatcher;
26 import javax.servlet.http.HttpServletRequest;
27 import javax.servlet.http.HttpServletResponse;
28 import javax.servlet.ServletRequest;
29
30 import edu.internet2.middleware.shibboleth.idp.authn.AuthenticationHandler;
31 import edu.internet2.middleware.shibboleth.idp.authn.LoginContext;
32 import java.net.Inet4Address;
33 import java.net.Inet6Address;
34
35 import org.apache.log4j.Logger;
36
37 import org.joda.time.DateTime;
38
39 /**
40  * IP Address authentication handler.
41  * 
42  * This "authenticates" a user based on their IP address. It operates in either
43  * default deny or default allow mode, and evaluates a given request against a
44  * list of blocked or permitted IPs. It supports both IPv4 and IPv6.
45  */
46 public class IPAddressHandler implements AuthenticationHandler {
47
48         /**
49          * Encapsulates a network address and a netmask on ipList.
50          */
51         protected class IPEntry {
52
53                 /** The network address. */
54                 private final BitSet networkAddress;
55
56                 /** The netmask. */
57                 private final BitSet netmask;
58
59                 /**
60                  * Construct a new IPEntry given a network address in CIDR format.
61                  * 
62                  * @param entry
63                  *            A CIDR-formatted network address/netmask
64                  * 
65                  * @throws UnknownHostException
66                  *             If entry is malformed.
67                  */
68                 public IPEntry(String entry) throws UnknownHostException {
69
70                         // quick sanity checks
71                         if (entry == null || entry.length() == 0) {
72                                 throw new UnknownHostException("entry is null.");
73                         }
74
75                         int cidrOffset = entry.indexOf("/");
76                         if (cidrOffset == -1) {
77                                 log.error("IPAddressHandler: invalid entry \"" + entry
78                                                 + "\" -- it lacks a netmask component.");
79                                 throw new UnknownHostException(
80                                                 "entry lacks a netmask component.");
81                         }
82
83                         // ensure that only one "/" is present.
84                         if (entry.indexOf("/", cidrOffset + 1) != -1) {
85                                 log.error("IPAddressHandler: invalid entry \"" + entry
86                                                 + "\" -- too many \"/\" present.");
87                                 throw new UnknownHostException(
88                                                 "entry has too many netmask components.");
89                         }
90
91                         String networkString = entry.substring(0, cidrOffset);
92                         String netmaskString = entry.substring(cidrOffset + 1, entry
93                                         .length());
94
95                         InetAddress tempAddr = InetAddress.getByName(networkString);
96                         networkAddress = byteArrayToBitSet(tempAddr.getAddress());
97
98                         int masklen = Integer.parseInt(netmaskString);
99                         int addrlen = networkAddress.length();
100
101                         // ensure that the netmask isn't too large
102                         if ((tempAddr instanceof Inet4Address) && (masklen > 32)) {
103                                 throw new UnknownHostException(
104                                                 "IPAddressHandler: Netmask is too large for an IPv4 address: "
105                                                                 + masklen);
106                         } else if ((tempAddr instanceof Inet6Address) && masklen > 128) {
107                                 throw new UnknownHostException(
108                                                 "IPAddressHandler: Netmask is too large for an IPv6 address: "
109                                                                 + masklen);
110                         }
111
112                         netmask = new BitSet(addrlen);
113                         netmask.set(addrlen - masklen, addrlen, true);
114                 }
115
116                 /**
117                  * Get the network address.
118                  * 
119                  * @return the network address.
120                  */
121                 public BitSet getNetworkAddress() {
122                         return networkAddress;
123                 }
124
125                 /**
126                  * Get the netmask.
127                  * 
128                  * @return the netmask.
129                  */
130                 public BitSet getNetmask() {
131                         return netmask;
132                 }
133         }
134
135         private static final Logger log = Logger.getLogger(IPAddressHandler.class);
136
137         /** the URI of the AuthnContextDeclRef or the AuthnContextClass */
138         private String authnMethodURI;
139
140         /** The return location */
141         private String returnLocation;
142
143         /** Are the IPs in ipList a permitted list or a deny list */
144         private boolean defaultDeny;
145
146         /** The list of denied or permitted IPs */
147         private List<IPEntry> ipList;
148
149         /** Creates a new instance of IPAddressHandler */
150         public IPAddressHandler() {
151         }
152
153         /**
154          * Set the permitted IP addresses.
155          * 
156          * If <code>defaultDeny</code> is <code>true</code> then only the IP
157          * addresses in <code>ipList</code> will be "authenticated." If
158          * <code>defaultDeny</code> is <code>false</code>, then all IP
159          * addresses except those in <code>ipList</code> will be authenticated.
160          * 
161          * @param entries
162          *            A list of IP addresses (with CIDR masks).
163          * @param defaultDeny
164          *            Does <code>ipList</code> contain a deny or permit list.
165          */
166         public void setEntries(final List<String> entries, boolean defaultDeny) {
167
168                 this.defaultDeny = defaultDeny;
169                 ipList = new CopyOnWriteArrayList<IPEntry>();
170
171                 for (String addr : entries) {
172                         try {
173                                 ipList
174                                                 .add(new edu.internet2.middleware.shibboleth.idp.authn.impl.IPAddressHandler.IPEntry(
175                                                                 addr));
176                         } catch (UnknownHostException ex) {
177                                 log.error("IPAddressHandler: Error parsing entry \"" + addr
178                                                 + "\". Ignoring.");
179                         }
180                 }
181         }
182
183         /** {@inheritDoc    */
184         public void setReturnLocation(String location) {
185                 this.returnLocation = location;
186         }
187
188         /** @{inheritDoc} */
189         public boolean supportsPassive() {
190                 return true;
191         }
192
193         /** {@inheritDoc} */
194         public boolean supportsForceAuthentication() {
195                 return true;
196         }
197
198         /** {@inheritDoc} */
199         public void logout(final HttpServletRequest request,
200                         final HttpServletResponse response, final String principal) {
201
202                 RequestDispatcher dispatcher = request
203                                 .getRequestDispatcher(returnLocation);
204                 // dispatcher.forward(request, response);
205         }
206
207         /** {@inheritDoc} */
208         public void login(final HttpServletRequest request,
209                         final HttpServletResponse response, final LoginContext loginCtx) {
210
211                 loginCtx.setAuthenticationAttempted();
212                 loginCtx.setAuthenticationInstant(new DateTime());
213
214                 if (defaultDeny) {
215                         handleDefaultDeny(request, response, loginCtx);
216                 } else {
217                         handleDefaultAllow(request, response, loginCtx);
218                 }
219         }
220
221         protected void handleDefaultDeny(HttpServletRequest request,
222                         HttpServletResponse response, LoginContext loginCtx) {
223
224                 boolean ipAllowed = searchIpList(request);
225
226                 if (ipAllowed) {
227                         loginCtx.setAuthenticationOK(true);
228                 } else {
229                         loginCtx.setAuthenticationOK(false);
230                         loginCtx
231                                         .setAuthenticationFailureMessage("User's IP is not in the permitted list.");
232                 }
233         }
234
235         protected void handleDefaultAllow(HttpServletRequest request,
236                         HttpServletResponse response, LoginContext loginCtx) {
237
238                 boolean ipDenied = searchIpList(request);
239
240                 if (ipDenied) {
241                         loginCtx.setAuthenticationOK(false);
242                         loginCtx
243                                         .setAuthenticationFailureMessage("Users's IP is in the deny list.");
244                 } else {
245                         loginCtx.setAuthenticationOK(true);
246                 }
247         }
248
249         /**
250          * Search the list of InetAddresses for the client's address.
251          * 
252          * @param request
253          *            The ServletReqeust
254          * 
255          * @return <code>true</code> if the client's address is in
256          *         <code>this.ipList</code>
257          */
258         private boolean searchIpList(final ServletRequest request) {
259
260                 boolean found = false;
261
262                 try {
263                         InetAddress addr = InetAddress.getByName(request.getRemoteAddr());
264                         BitSet addrbits = byteArrayToBitSet(addr.getAddress());
265
266                         for (IPEntry entry : ipList) {
267
268                                 BitSet netaddr = entry.getNetworkAddress();
269                                 BitSet netmask = entry.getNetmask();
270
271                                 addrbits.and(netmask);
272                                 if (addrbits.equals(netaddr)) {
273                                         found = true;
274                                         break;
275                                 }
276                         }
277
278                 } catch (UnknownHostException ex) {
279                         log.error("Error resolving hostname: ", ex);
280                         return false;
281                 }
282
283                 return found;
284         }
285
286         /**
287          * Converts a byte array to a BitSet.
288          * 
289          * The supplied byte array is assumed to have the most signifigant bit in
290          * element 0.
291          * 
292          * @param bytes
293          *            the byte array with most signifigant bit in element 0.
294          * 
295          * @return the BitSet
296          */
297         protected static BitSet byteArrayToBitSet(final byte[] bytes) {
298
299                 BitSet bits = new BitSet();
300
301                 for (int i = 0; i < bytes.length * 8; i++) {
302                         if ((bytes[bytes.length - i / 8 - 1] & (1 << (i % 8))) > 0) {
303                                 bits.set(i);
304                         }
305                 }
306
307                 return bits;
308         }
309 }