package com.smtscript.utils;

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.charset.Charset;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Minimal forwarding HTTP proxy.
 *
 * <p>Each accepted client connection carries exactly one request whose request line uses an
 * absolute URL, for example:
 *
 * <pre>GET http://192.168.191.1:81/aaa HTTP/1.1</pre>
 *
 * The request line is rewritten to origin form ({@code GET /aaa HTTP/1.1}), the remaining header
 * lines are forwarded unchanged to the host/port taken from the URL (port 80 when the URL has no
 * port), and the remote response header is relayed back to the client. After one
 * request/response pair both sockets are closed.
 *
 * <p>Limitations of this implementation: only {@code GET} and {@code POST} request lines are
 * recognised, and message bodies are relayed only when framed by a {@code Content-Length} header;
 * chunked or connection-close delimited bodies are not forwarded.
 */
public class HttpProxyServer {
    /** Lower-case header-name prefix used to detect the body length (matched case-insensitively). */
    private static final String CONTENT_LENGTH = "content-length:";

    /**
     * Charset used to decode and re-encode header lines. ISO-8859-1 maps every byte to exactly one
     * char, so header bytes are forwarded unchanged regardless of the platform default charset.
     */
    private static final Charset HEAD_CHARSET = Charset.forName("ISO-8859-1");

    /** Maximum number of bytes read for one header line before giving up. */
    private static final int MAX_LINE_BYTES = 4096;

    /** Groups: 1 method, 2 host, 3 optional ":port", 4 path/query, 5 optional " HTTP/1.1". */
    private static final Pattern _patHttpHead =
            Pattern.compile("^(GET|POST)\\s+http://([^/:]+)(:\\d+)?(.+?)(\\s+HTTP/1.1)?$");

    private ServerSocket _sockServer;
    private ExecutorService _threadPool;
    private volatile boolean _serverRunning = true;

    /**
     * Hook invoked after one request/response exchange has been relayed. The default does nothing;
     * subclasses may override it to observe traffic. It runs on a worker thread of the internal
     * pool, so overrides may be called concurrently for different connections.
     *
     * @param remoteHost   host name taken from the request URL
     * @param remotePort   port connected to (80 when the URL had no port)
     * @param remoteURL    path/query part of the request URL
     * @param requestHead  rewritten request line plus header lines, CRLF-terminated
     * @param requestBody  request body, or {@code null} when no positive Content-Length was sent
     * @param responseHead status line plus response header lines, CRLF-terminated
     * @param responseBody response body, or {@code null} when no positive Content-Length was sent
     */
    protected void onProxyRequest(String remoteHost, int remotePort, String remoteURL,
            String requestHead, byte[] requestBody, String responseHead, byte[] responseBody) {
    }

    /**
     * Starts listening and launches the accept loop on a background thread.
     *
     * @param port local port to listen on; 0 lets the system pick a free port
     * @return the local port actually bound
     * @throws Exception if the server socket cannot be opened
     */
    public int create(int port) throws Exception {
        _threadPool = Executors.newCachedThreadPool();
        _sockServer = new ServerSocket(port);
        port = _sockServer.getLocalPort();
        // Wake accept() up once a second so the loop can notice close() clearing _serverRunning.
        _sockServer.setSoTimeout(1000);
        _threadPool.execute(new Runnable() {
            @Override
            public void run() {
                acceptLoop();
            }
        });
        return port;
    }

    /**
     * Stops accepting connections, waits for in-flight exchanges to finish and closes the
     * listening socket. Safe to call even if {@link #create(int)} was never called.
     *
     * @throws Exception if waiting is interrupted or closing the server socket fails
     */
    public void close() throws Exception {
        _serverRunning = false;
        if (_threadPool != null) {
            // awaitTermination() only returns after a shutdown request; without shutdown() the
            // cached pool never terminates and this call would block for the whole timeout.
            _threadPool.shutdown();
            _threadPool.awaitTermination(Integer.MAX_VALUE, TimeUnit.DAYS);
        }
        if (_sockServer != null) {
            _sockServer.close();
        }
    }

    /** Accepts client connections until close() is requested, handing each to the pool. */
    private void acceptLoop() {
        while (_serverRunning) {
            Socket accepted;
            try {
                accepted = _sockServer.accept();
            } catch (SocketTimeoutException e) {
                continue; // periodic wake-up to re-check _serverRunning
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            if (!dispatch(accepted)) {
                return;
            }
        }
    }

    /**
     * Schedules one client connection on the pool.
     *
     * @return false if the pool has already been shut down (the socket is then closed)
     */
    private boolean dispatch(final Socket sockAccept) {
        try {
            _threadPool.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        handleConnection(sockAccept);
                    } catch (Exception ex) {
                        throw new RuntimeException(ex);
                    }
                }
            });
            return true;
        } catch (RejectedExecutionException e) {
            // close() shut the pool down between accept() and execute(); drop this connection.
            try {
                sockAccept.close();
            } catch (IOException ignored) {
                // nothing useful to do: the connection is being discarded anyway
            }
            return false;
        }
    }

    /** Relays exactly one request/response pair for a client connection, then closes both ends. */
    private void handleConnection(Socket sockAccept) throws Exception {
        Socket sockRemote = new Socket();
        try {
            // One buffered stream per socket, shared by header and body reads so no bytes are lost.
            InputStream isAccept = new BufferedInputStream(sockAccept.getInputStream());

            // Read and parse the request line.
            String line = readSocketLine(isAccept);
            Matcher m = _patHttpHead.matcher(line);
            if (!m.find()) {
                throw new Exception("can't parse http head : " + line);
            }
            String method = m.group(1);
            String remoteHost = m.group(2);
            String remotePort = m.group(3);
            String remoteUrl = m.group(4);
            String tail = m.group(5);

            // Rewrite to origin form, then copy the request header lines.
            StringBuilder sbRequest = new StringBuilder();
            sbRequest.append(String.format("%s %s%s\r\n", method, remoteUrl, tail == null ? "" : tail));
            int requestLength = readHeaderLines(isAccept, sbRequest);

            // Connect to the remote host; group 3 is either absent or ":<digits>".
            int remotePortV = (remotePort == null || remotePort.length() == 0)
                    ? 80 : Integer.parseInt(remotePort.substring(1));
            sockRemote.connect(new InetSocketAddress(remoteHost, remotePortV));

            // Send the request header, then the body if a Content-Length was given.
            OutputStream osRemote = sockRemote.getOutputStream();
            osRemote.write(sbRequest.toString().getBytes(HEAD_CHARSET));
            byte[] dataRequest = null;
            if (requestLength > 0) {
                dataRequest = readSocketBytes(isAccept, requestLength);
                osRemote.write(dataRequest);
            }

            // Read the response header and relay it to the client.
            InputStream isRemote = new BufferedInputStream(sockRemote.getInputStream());
            StringBuilder sbResponse = new StringBuilder();
            int responseLength = readHeaderLines(isRemote, sbResponse);
            OutputStream osAccept = sockAccept.getOutputStream();
            osAccept.write(sbResponse.toString().getBytes(HEAD_CHARSET));

            // Relay the response body only when framed by Content-Length.
            byte[] dataResponse = null;
            if (responseLength > 0) {
                dataResponse = readSocketBytes(isRemote, responseLength);
                osAccept.write(dataResponse);
            }

            // Notify the subclass hook.
            onProxyRequest(remoteHost, remotePortV, remoteUrl, sbRequest.toString(), dataRequest,
                    sbResponse.toString(), dataResponse);
        } finally {
            // Nested so the remote socket is closed even if closing the client socket fails.
            try {
                sockAccept.close();
            } finally {
                sockRemote.close();
            }
        }
    }

    /**
     * Reads header lines up to and including the terminating empty line, appending each one plus
     * CRLF to {@code sb}.
     *
     * @return the value of the last Content-Length header seen, or 0 if there was none
     */
    private int readHeaderLines(InputStream is, StringBuilder sb) throws Exception {
        int contentLength = 0;
        while (true) {
            String line = readSocketLine(is);
            sb.append(line).append("\r\n");
            if (line.length() == 0) {
                return contentLength;
            }
            if (line.regionMatches(true, 0, CONTENT_LENGTH, 0, CONTENT_LENGTH.length())) {
                contentLength = Integer.parseInt(line.substring(CONTENT_LENGTH.length()).trim());
            }
        }
    }

    /**
     * Reads one line terminated by LF, dropping any CR bytes, and decodes it as ISO-8859-1.
     * Both CRLF and bare LF terminators are accepted.
     *
     * @throws Exception on end of stream before LF, or if the line exceeds 4096 bytes
     */
    private String readSocketLine(InputStream is) throws Exception {
        ByteArrayOutputStream line = new ByteArrayOutputStream(128);
        for (int consumed = 0; consumed < MAX_LINE_BYTES; consumed++) {
            int ch = is.read();
            if (ch < 0) {
                throw new Exception("can't find LOF in receive HTTP head");
            }
            if (ch == '\n') {
                return new String(line.toByteArray(), HEAD_CHARSET);
            }
            if (ch != '\r') {
                line.write(ch);
            }
        }
        throw new Exception("receive HTTP head more than 4096 bytes");
    }

    /**
     * Reads exactly {@code length} bytes.
     *
     * @throws Exception if the stream ends before {@code length} bytes were read
     */
    private byte[] readSocketBytes(InputStream is, int length) throws Exception {
        byte[] data = new byte[length];
        int offset = 0;
        while (offset < length) {
            int n = is.read(data, offset, length - offset);
            if (n < 0) {
                throw new Exception("can't read socket bytes length");
            }
            offset += n;
        }
        return data;
    }
}