001/*
002 * Licensed under the Apache License, Version 2.0 (the "License");
003 * you may not use this file except in compliance with the License.
004 * You may obtain a copy of the License at
005 *
006 *     http://www.apache.org/licenses/LICENSE-2.0
007 *
008 * Unless required by applicable law or agreed to in writing, software
009 * distributed under the License is distributed on an "AS IS" BASIS,
010 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
011 * See the License for the specific language governing permissions and
012 * limitations under the License.
013 */
014package org.gbif.ws.server;
015
016import java.io.BufferedReader;
017import java.io.ByteArrayInputStream;
018import java.io.IOException;
019import java.io.InputStreamReader;
020import java.nio.charset.StandardCharsets;
021import java.util.ArrayList;
022import java.util.Collections;
023import java.util.Enumeration;
024import java.util.List;
025import java.util.Optional;
026
027import javax.servlet.ServletInputStream;
028import javax.servlet.http.HttpServletRequest;
029import javax.servlet.http.HttpServletRequestWrapper;
030
031import org.apache.commons.io.IOUtils;
032import org.apache.commons.lang3.StringUtils;
033import org.springframework.http.HttpHeaders;
034
035public class GbifHttpServletRequestWrapper extends HttpServletRequestWrapper {
036
037  private String content;
038
039  private HttpHeaders httpHeaders;
040
041  private HttpServletRequest wrappedRequest;
042
043  public GbifHttpServletRequestWrapper(HttpServletRequest request) {
044    this(request, false);
045  }
046
047  /**
048   *
049   */
050  public GbifHttpServletRequestWrapper(HttpServletRequest request, boolean wrapContent) {
051    this(request, null, wrapContent);
052    if (!wrapContent) {
053      this.wrappedRequest = request;
054    }
055  }
056
057  public GbifHttpServletRequestWrapper(
058      HttpServletRequest request, String contentAsString, boolean wrapContent) {
059    super(request);
060
061    try {
062      if (StringUtils.isNotEmpty(contentAsString)) {
063        content = contentAsString;
064      } else if (request.getInputStream() != null && wrapContent) {
065        content = IOUtils.toString(request.getInputStream(), request.getCharacterEncoding());
066      } else {
067        content = null;
068      }
069    } catch (IOException e) {
070      throw new RuntimeException("Stream can't be read", e);
071    }
072
073    httpHeaders = getHttpHeaders(request);
074  }
075
076  @Override
077  public ServletInputStream getInputStream() throws IOException {
078    if (wrappedRequest != null) {
079      return wrappedRequest.getInputStream();
080    } else {
081      final ByteArrayInputStream byteArrayInputStream =
082          new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8));
083      return new DelegatingServletInputStream(byteArrayInputStream);
084    }
085  }
086
087  @Override
088  public BufferedReader getReader() throws IOException {
089    return new BufferedReader(new InputStreamReader(this.getInputStream(), StandardCharsets.UTF_8));
090  }
091
092  private HttpHeaders getHttpHeaders(HttpServletRequest request) {
093    final HttpHeaders requestHeaders = new HttpHeaders();
094    Enumeration<String> headerNames = request.getHeaderNames();
095
096    if (headerNames != null) {
097      while (headerNames.hasMoreElements()) {
098        String currentHeaderName = headerNames.nextElement();
099        requestHeaders.set(currentHeaderName, request.getHeader(currentHeaderName));
100      }
101    }
102
103    return requestHeaders;
104  }
105
106  public String getContent() {
107    return content;
108  }
109
110  public HttpHeaders getHttpHeaders() {
111    return new HttpHeaders(httpHeaders);
112  }
113
114  public void overwriteLanguageHeader(String newValue) {
115    httpHeaders.set(HttpHeaders.ACCEPT_LANGUAGE, newValue);
116  }
117
118  @Override
119  public String getHeader(String name) {
120    if (getHttpHeaders().containsKey(name)) {
121      return getHttpHeaders().getFirst(name);
122    }
123    return super.getHeader(name);
124  }
125
126  @Override
127  public Enumeration<String> getHeaderNames() {
128    return Collections.enumeration(httpHeaders.keySet());
129  }
130
131  @Override
132  public Enumeration<String> getHeaders(String name) {
133    List<String> values = new ArrayList<>();
134    if (httpHeaders.containsKey(name)) {
135      Optional.ofNullable(httpHeaders.get(name)).ifPresent(values::addAll);
136    }
137    return Collections.enumeration(values);
138  }
139}