001    /*
002     * Copyright (c) 2007 Mozilla Foundation
003     *
004     * Permission is hereby granted, free of charge, to any person obtaining a 
005     * copy of this software and associated documentation files (the "Software"), 
006     * to deal in the Software without restriction, including without limitation 
007     * the rights to use, copy, modify, merge, publish, distribute, sublicense, 
008     * and/or sell copies of the Software, and to permit persons to whom the 
009     * Software is furnished to do so, subject to the following conditions:
010     *
011     * The above copyright notice and this permission notice shall be included in 
012     * all copies or substantial portions of the Software.
013     *
014     * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
015     * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
016     * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 
017     * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
018     * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
019     * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
020     * DEALINGS IN THE SOFTWARE.
021     */
022    
023    package nu.validator.servlet;
024    
025    import java.io.BufferedReader;
026    import java.io.ByteArrayInputStream;
027    import java.io.IOException;
028    import java.io.InputStream;
029    import java.io.InputStreamReader;
030    import java.io.Reader;
031    import java.nio.charset.CharacterCodingException;
032    import java.nio.charset.Charset;
033    import java.nio.charset.CharsetDecoder;
034    import java.nio.charset.CharsetEncoder;
035    import java.nio.charset.CodingErrorAction;
036    import java.util.ArrayList;
037    import java.util.Collection;
038    import java.util.Collections;
039    import java.util.Enumeration;
040    import java.util.HashMap;
041    import java.util.List;
042    import java.util.Map;
043    import java.util.regex.Matcher;
044    import java.util.regex.Pattern;
045    
046    import javax.servlet.Filter;
047    import javax.servlet.FilterChain;
048    import javax.servlet.FilterConfig;
049    import javax.servlet.ServletException;
050    import javax.servlet.ServletInputStream;
051    import javax.servlet.ServletRequest;
052    import javax.servlet.ServletResponse;
053    import javax.servlet.http.HttpServletRequest;
054    import javax.servlet.http.HttpServletRequestWrapper;
055    import javax.servlet.http.HttpServletResponse;
056    
057    import nu.validator.servletfilter.DelegatingServletInputStream;
058    
059    import org.apache.commons.fileupload.FileItemIterator;
060    import org.apache.commons.fileupload.FileItemStream;
061    import org.apache.commons.fileupload.FileUploadException;
062    import org.apache.commons.fileupload.servlet.ServletFileUpload;
063    
064    public class MultipartFormDataFilter implements Filter {
065    
066        private static Pattern EXTENSION = Pattern.compile("^.*\\.(.+)$");
067        
068        private static Map<String, String> extensionToType = new HashMap<String, String>();
069        
070        static {
071            extensionToType.put("html", "text/html");
072            extensionToType.put("htm", "text/html");
073            extensionToType.put("xhtml", "application/xhtml+xml");
074            extensionToType.put("xht", "application/xhtml+xml");
075            extensionToType.put("atom", "application/atom+xml");
076            extensionToType.put("rng", "application/xml");
077            extensionToType.put("xsl", "application/xml");
078            extensionToType.put("xml", "application/xml");
079            extensionToType.put("dbk", "application/xml");
080        }
081        
082        private static String utf8ByteStreamToString(InputStream stream) throws IOException {
083            CharsetDecoder dec = Charset.forName("UTF-8").newDecoder();
084            dec.onMalformedInput(CodingErrorAction.REPORT);
085            dec.onUnmappableCharacter(CodingErrorAction.REPORT);
086            Reader reader = new InputStreamReader(stream, dec);
087            StringBuilder builder = new StringBuilder();
088            int c;
089            int i = 0;
090            while ((c = reader.read()) != -1) {
091                if (i > 2048) {
092                    throw new IOException("Form field value too large.");
093                }
094                builder.append((char)c);
095                i++;
096            }
097            return builder.toString();
098        }
099    
100        private static void putParam(Map<String, String[]> params, String key, String value) {
101            String[] oldVal = params.get(key);
102            if (oldVal == null) {
103                String[] arr = new String[1];
104                arr[0] = value;
105                params.put(key, arr);
106            } else {
107                for (int i = 0; i < oldVal.length; i++) {
108                    String string = oldVal[i];
109                    if (string.equals(value)) {
110                        return;
111                    }
112                }
113                String[] arr = new String[oldVal.length + 1];
114                System.arraycopy(oldVal, 0, arr, 0, oldVal.length);
115                arr[oldVal.length] = value;
116                params.put(key, arr);            
117            }
118        }
119        
120        public void destroy() {
121        }
122    
123        public void doFilter(ServletRequest req, ServletResponse res,
124                FilterChain chain) throws IOException, ServletException {
125            HttpServletRequest request = (HttpServletRequest) req;
126            HttpServletResponse response = (HttpServletResponse) res;
127            if (ServletFileUpload.isMultipartContent(request)) {
128                try {
129                    boolean utf8 = false;
130                    String contentType = null;
131                    Map<String, String[]> params = new HashMap<String, String[]>();
132                    InputStream fileStream = null;
133                ServletFileUpload upload = new ServletFileUpload();
134                FileItemIterator iter = upload.getItemIterator(request);
135                while (iter.hasNext()) {
136                    FileItemStream fileItemStream = iter.next();
137                    if (fileItemStream.isFormField()) {
138                        String fieldName = fileItemStream.getFieldName();
139                        if ("content".equals(fieldName)) {
140                            utf8 = true;
141                            String[] parser = params.get("parser");
142                            if (parser != null && parser[0].startsWith("xml")) {
143                                contentType = "application/xml";
144                            } else {
145                                contentType = "text/html";
146                            }
147                            fileStream = fileItemStream.openStream();                      
148                            break;
149                        } else {
150                            putParam(params, fieldName, utf8ByteStreamToString(fileItemStream.openStream()));
151                        }
152                    } else {
153                        String fileName = fileItemStream.getName();
154                        if (fileName != null) {
155                            putParam(params,  fileItemStream.getFieldName(), fileName); 
156                            Matcher m = EXTENSION.matcher(fileName);
157                            if (m.matches()) {
158                                contentType = extensionToType.get(m.group(1));
159                            }
160                        }
161                        if (contentType == null) {
162                            contentType = fileItemStream.getContentType();      
163                        }
164                        fileStream = fileItemStream.openStream();
165                        break;
166                    }
167                }
168                if (fileStream == null) {
169                    fileStream = new ByteArrayInputStream(new byte[0]);
170                }
171                chain.doFilter(new RequestWrapper(request, params, contentType, utf8, fileStream), response);
172                } catch (FileUploadException e) {
173                    response.sendError(415, e.getMessage());
174                } catch (CharacterCodingException e) {
175                    response.sendError(415, e.getMessage());                
176                } catch (IOException e) {
177                    response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage());                
178                }
179            } else {
180                chain.doFilter(req, res);
181            }
182        }
183    
184        public void init(FilterConfig arg0) throws ServletException {
185        }
186    
187        private class RequestWrapper extends HttpServletRequestWrapper {
188    
189            private final Map<String, String[]> params;
190            private final String contentType;
191            private final boolean utf8;
192            private final ServletInputStream stream;
193            
194            public RequestWrapper(HttpServletRequest req, Map<String, String[]> params, String contentType, boolean utf8, InputStream stream) {
195                super(req);
196                this.params = Collections.unmodifiableMap(params);
197                this.contentType = contentType;
198                this.utf8 = utf8;
199                this.stream = new DelegatingServletInputStream(stream);
200            }
201    
202            /**
203             * @see javax.servlet.http.HttpServletRequestWrapper#getDateHeader(java.lang.String)
204             */
205            @Override
206            public long getDateHeader(String name) {
207                if ("Content-Length".equalsIgnoreCase(name)) {
208                    return -1;
209                } else if ("Content-MD5".equalsIgnoreCase(name)) {
210                    return -1;
211                } else if ("Content-Encoding".equalsIgnoreCase(name)) {
212                    return -1;
213                } else if ("Content-Type".equalsIgnoreCase(name)) {
214                    return -1;
215                } else {
216                    return super.getDateHeader(name);
217                }
218            }
219    
220            /**
221             * @see javax.servlet.http.HttpServletRequestWrapper#getHeader(java.lang.String)
222             */
223            @Override
224            public String getHeader(String name) {
225                if ("Content-Length".equalsIgnoreCase(name)) {
226                    return null;
227                } else if ("Content-MD5".equalsIgnoreCase(name)) {
228                    return null;
229                } else if ("Content-Encoding".equalsIgnoreCase(name)) {
230                    return null;
231                } else if ("Content-Type".equalsIgnoreCase(name)) {
232                    return getContentType();
233                } else {
234                    return super.getHeader(name);
235                }
236            }
237    
238            /**
239             * @see javax.servlet.http.HttpServletRequestWrapper#getHeaderNames()
240             */
241            @Override
242            public Enumeration getHeaderNames() {
243                Enumeration e = super.getHeaderNames();
244                List<String> list = new ArrayList<String>();
245                while (e.hasMoreElements()) {
246                    String name = (String) e.nextElement();
247                    if ("Content-Length".equalsIgnoreCase(name)) {
248                        continue;
249                    } else if ("Content-MD5".equalsIgnoreCase(name)) {
250                        continue;
251                    } else if ("Content-Encoding".equalsIgnoreCase(name)) {
252                        continue;
253                    } else if ("Content-Type".equalsIgnoreCase(name)) {
254                        list.add(getContentType());
255                    } else {
256                        list.add(name);
257                    }                
258                }
259                return Collections.enumeration(list);
260            }
261    
262            /**
263             * @see javax.servlet.http.HttpServletRequestWrapper#getHeaders(java.lang.String)
264             */
265            @SuppressWarnings("unchecked")
266            @Override
267            public Enumeration getHeaders(String name) {
268                if ("Content-Length".equalsIgnoreCase(name)) {
269                    return Collections.enumeration(Collections.EMPTY_SET);
270                } else if ("Content-MD5".equalsIgnoreCase(name)) {
271                    return Collections.enumeration(Collections.EMPTY_SET);
272                } else if ("Content-Encoding".equalsIgnoreCase(name)) {
273                    return Collections.enumeration(Collections.EMPTY_SET);
274                } else if ("Content-Type".equalsIgnoreCase(name)) {
275                    return Collections.enumeration(Collections.singleton(getContentType()));
276                } else {
277                    return super.getHeaders(name);
278                }
279            }
280    
281            /**
282             * @see javax.servlet.http.HttpServletRequestWrapper#getIntHeader(java.lang.String)
283             */
284            @Override
285            public int getIntHeader(String name) {
286                if ("Content-Length".equalsIgnoreCase(name)) {
287                    return -1;
288                } else if ("Content-MD5".equalsIgnoreCase(name)) {
289                    return -1;
290                } else if ("Content-Encoding".equalsIgnoreCase(name)) {
291                    return -1;
292                } else if ("Content-Type".equalsIgnoreCase(name)) {
293                    return -1;
294                } else {
295                    return super.getIntHeader(name);
296                }
297            }
298    
299            /**
300             * @see javax.servlet.ServletRequestWrapper#getCharacterEncoding()
301             */
302            @Override
303            public String getCharacterEncoding() {
304                return utf8 ? "utf-8" : null;
305            }
306    
307            /**
308             * @see javax.servlet.ServletRequestWrapper#getContentLength()
309             */
310            @Override
311            public int getContentLength() {
312                return -1;
313            }
314    
315            /**
316             * @see javax.servlet.ServletRequestWrapper#getContentType()
317             */
318            @Override
319            public String getContentType() {
320                return utf8 ? contentType + "; charset=utf-8" : contentType;
321            }
322    
323            /**
324             * @see javax.servlet.ServletRequestWrapper#getInputStream()
325             */
326            @Override
327            public ServletInputStream getInputStream() throws IOException {
328                return stream;
329            }
330    
331            /**
332             * @see javax.servlet.ServletRequestWrapper#getParameter(java.lang.String)
333             */
334            @Override
335            public String getParameter(String key) {
336                String[] arr = params.get(key);
337                if (arr == null) {
338                    return null;
339                } else {
340                    return arr[0];
341                }
342            }
343    
344            /**
345             * @see javax.servlet.ServletRequestWrapper#getParameterMap()
346             */
347            @Override
348            public Map getParameterMap() {
349                return params;
350            }
351    
352            /**
353             * @see javax.servlet.ServletRequestWrapper#getParameterNames()
354             */
355            @Override
356            public Enumeration getParameterNames() {
357                return Collections.enumeration(params.keySet());
358            }
359    
360            /**
361             * @see javax.servlet.ServletRequestWrapper#getParameterValues(java.lang.String)
362             */
363            @Override
364            public String[] getParameterValues(String key) {
365                return params.get(key);
366            }
367    
368            /**
369             * @see javax.servlet.ServletRequestWrapper#getReader()
370             */
371            @Override
372            public BufferedReader getReader() throws IOException {
373                CharsetDecoder dec = Charset.forName("UTF-8").newDecoder();
374                dec.onMalformedInput(CodingErrorAction.REPORT);
375                dec.onUnmappableCharacter(CodingErrorAction.REPORT);
376                Reader reader = new InputStreamReader(stream, dec);
377                return new BufferedReader(reader);
378            }
379            
380        }
381        
382    }