Java 实现 TCP 请求转发

  • Post author:
  • Post category:java


背景：前段时间做了一个区块链相关的项目，需要遍历区块。由于存在端口限制，无法直接访问，于是写了这个转发工具。Nginx 同样能完成此项工作，但我选择了用 Java 实现。

完整代码如下，其中 ForwardMsg 类保存转发配置（本地监听端口、目标地址与目标端口）。

import java.io.BufferedReader;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URLDecoder;
import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
import java.util.regex.Matcher;
import java.util.regex.Pattern;


/**
 * Entry point of the forwarder: accepts client connections on a local port and queues
 * them (as {@link Connect}) for {@link MessageListener} to forward.
 *
 * @Computer XYSM
 * @Program: cmd
 * @Package: com.hsy.cmd
 * @Description TCP/HTTP request forwarder
 * @Author System
 * @CreateTime 2021-01-19 09:50
 * @Version 1.0.0
 */
public class Forward {

    /** Queued connections older than this many milliseconds are closed and dropped. */
    private static final long STALE_MILLIS = 30000;

    /**
     * Starts one acceptor thread for the configured mapping and one queue-consumer thread.
     */
    public static void main(String[] args) {
        ForwardMsg cmd  = new ForwardMsg(10253, 36253, "127.0.0.1", "CMD 客户端");
        new Thread(() -> new Forward().accept(cmd)).start();

        new Thread(() -> new MessageListener().listening()).start();
    }

    /** Forwarding configuration applied to every accepted connection. */
    private ForwardMsg msg;

    /**
     * Listens on {@code forwardMsg.localPort} forever. Each accepted socket is wrapped in a
     * {@link Connect} and appended to {@link MessageListener#vector}.
     *
     * @param forwardMsg forwarding configuration (local port, target address/port)
     */
    private void accept(ForwardMsg forwardMsg){
        System.out.println(forwardMsg.toString());
        this.msg = forwardMsg;
        // try-with-resources so the listening port is released if the loop ever exits
        try (ServerSocket server = new ServerSocket(forwardMsg.localPort)) {
            while (true){
                Socket socket = null;
                try{
                    socket = server.accept();
                    evictStale(System.currentTimeMillis());
                    MessageListener.vector.add(new Connect(socket, this.msg));
                }catch(Exception e){
                    e.printStackTrace();
                    // e.g. the upstream connection in Connect's constructor failed:
                    // the accepted client socket would otherwise leak
                    if (socket != null){
                        close(socket);
                    }
                }
            }
        } catch (IOException e) {
            // previously swallowed silently, hiding e.g. "port already in use"
            System.err.println("cannot listen on port " + forwardMsg.localPort);
            e.printStackTrace();
        }
    }

    /**
     * Closes and removes queued connections that have waited longer than
     * {@link #STALE_MILLIS}. {@code Vector.removeIf} runs under the vector's own lock, so
     * this cannot race with the listener thread dequeuing (unlike iterating the vector).
     *
     * @param now current time in epoch milliseconds
     */
    private static void evictStale(long now){
        MessageListener.vector.removeIf(c -> {
            if (now - c.timestamp > STALE_MILLIS){
                close(c);
                return true;
            }
            return false;
        });
    }

    /**
     * Closes each non-null resource, ignoring failures (best-effort cleanup).
     *
     * @param closeables resources to close; null entries are skipped
     */
    public static void close(Closeable  ... closeables){
        for (Closeable closeable : closeables) {
            if (closeable == null){
                continue;
            }
            try{
                closeable.close();
            }catch(Exception ignored){
                // best-effort: failing to close one resource must not stop the others
            }
        }
    }
}

/**
 * Consumes {@link #vector} and forwards each queued connection on its own thread: the
 * client's request is sent to the target, then the target's reply is copied back, and
 * both sockets are closed.
 */
class MessageListener{

    /** Pause between polls while the queue is empty (ms), instead of spinning a CPU core. */
    private static final long IDLE_SLEEP_MILLIS = 10;

    public MessageListener() {
        super();
    }

    /** Pending connections: appended by Forward, consumed from the head here. */
    public static volatile Vector<Connect> vector = new Vector<Connect>();

    /**
     * Dequeues connections until the current thread is interrupted, starting one
     * forwarding thread per connection.
     */
    public void listening(){
        System.out.println("消息队列工作中。。。。");
        while (!Thread.currentThread().isInterrupted()){
            Connect pack = poll();
            if (pack == null){
                // the original spun on vector.size() here, pinning a core at 100%
                try{
                    Thread.sleep(IDLE_SLEEP_MILLIS);
                }catch(InterruptedException e){
                    Thread.currentThread().interrupt();
                }
                continue;
            }
            new Thread(() -> handle(pack)).start();
        }
    }

    /**
     * Atomically removes and returns the head of the queue.
     *
     * @return the oldest queued connection, or null when the queue is empty
     */
    private static Connect poll(){
        Vector<Connect> queue = vector;
        // check-and-remove must be atomic: another thread may remove entries in between,
        // and an exception here would kill the only consumer thread
        synchronized (queue){
            return queue.isEmpty() ? null : queue.remove(0);
        }
    }

    /** Forwards one request/response exchange and always closes both sockets afterwards. */
    private void handle(Connect pack){
        try{
            System.out.println(pack.toString());
            /* forward the request */
            this.forwardStream(pack);
            /* forward the response (only if a target was resolved) */
            if (pack.target != null){
                this.transferTo(pack.target, pack.source);
            }
        }catch(Exception e){
            e.printStackTrace();
        }finally {
            Forward.close(pack);
        }
    }

    /**
     * Copies bytes from {@code source} to {@code target}. Errors are reported and otherwise
     * ignored, because either peer may already have closed its socket.
     *
     * @param source socket to read from
     * @param target socket to write to
     */
    public void transferTo(Socket source, Socket target){
        try{
            byte[] buffer = new byte[512];
            InputStream  is = source.getInputStream();
            OutputStream os = target.getOutputStream();
            int read;
            while ((read = is.read(buffer)) >= 0) {
                os.write(buffer, 0, read);
                // Heuristic end-of-message: a short read is taken to mean the peer is done,
                // since an HTTP/1.1 peer may keep the connection open rather than send EOF.
                // NOTE(review): a reply arriving in several TCP segments can be cut short
                // here; exact framing would need Content-Length / chunked parsing.
                if (read < buffer.length){
                    break;
                }
            }
            os.flush();
        }catch(Exception e){
            System.err.println("transfer error: " + e);
        }
    }

    /**
     * Reads one HTTP request from {@code connect.source}, rewrites its Host header (and,
     * without a fixed target, resolves the target from the request), and writes it to
     * {@code connect.target}.
     */
    private void forwardStream(Connect connect){
        try{
            /* read request line and headers */
            HttpReader httpReader = new HttpReader(connect.source);
            if (connect.target == null){
                httpReader.setConnect(connect);
            }else {
                httpReader.setTargetHost(hostHeaderFor(connect.target));
            }
            /* read the body */
            httpReader.parseBody();
            if (connect.target == null){
                System.out.println("write error: no target resolved for " + connect.source);
                return;
            }
            /* forward the rebuilt request; default charset matches HttpReader's decoding */
            OutputStream os = connect.target.getOutputStream();
            os.write(httpReader.getSource().getBytes());
            os.flush();
        }catch(Exception e){
            e.printStackTrace();
            System.out.println("write error");
        }
    }

    /**
     * Builds a {@code host:port} Host-header value for a connected socket. The old
     * {@code getRemoteSocketAddress().toString().substring(1)} broke when the target was
     * given by host name ({@code "name/ip:port"}).
     */
    private static String hostHeaderFor(Socket target){
        String address = target.getInetAddress().getHostAddress();
        if (address.indexOf(':') > -1){
            // IPv6 literals must be bracketed in a Host header
            address = '[' + address + ']';
        }
        return address + ':' + target.getPort();
    }
}


/**
 * One client connection queued for forwarding, paired with its upstream socket.
 */
class Connect implements Closeable{
    /** Creation time in epoch milliseconds; Forward uses it to drop connections that waited too long. */
    public  long  timestamp;
    /** Accepted client socket. */
    public Socket source;
    /** Upstream socket; null until resolved from the request when no fixed target is configured. */
    public Socket target;

    /**
     * @param source accepted client socket
     * @param msg    forwarding configuration; when {@code msg.targetUri} is non-null the
     *               upstream connection is opened here
     * @throws IOException if the upstream connection cannot be opened
     */
    public Connect(Socket source, ForwardMsg msg) throws IOException {
        this.timestamp  = System.currentTimeMillis();
        this.source     = source;
        this.target     = msg.targetUri == null ? null : new Socket(msg.targetUri, msg.targetPort);
    }

    /**
     * Closes both sockets. {@code source} is closed even if closing {@code target} fails;
     * previously such a failure leaked the client socket.
     *
     * @throws IOException the first close failure, with any later one attached as suppressed
     */
    @Override
    public void close() throws IOException {
        IOException failure = null;
        if (this.target != null){
            try{
                this.target.close();
            }catch(IOException e){
                failure = e;
            }
        }
        if (this.source != null){
            try{
                this.source.close();
            }catch(IOException e){
                if (failure == null){
                    failure = e;
                }else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null){
            throw failure;
        }
    }

    @Override
    public String toString() {
        return "请求转发: " + source.toString() + " --> " + (target == null ? "请求指定" : target.toString());
    }
}


/**
 * Forwarding configuration: local listening port, upstream target, and a display title.
 */
class ForwardMsg{

    /** Local port the forwarder listens on. */
    public int localPort;

    /** Upstream port. */
    public int targetPort;

    /** Upstream host; may be null, in which case Connect leaves its target unset. */
    public String targetUri;

    /** Human-readable label, used only by toString(); may be null. */
    public String title;

    /** Creates a configuration without a title. */
    public ForwardMsg(int localPort, int targetPort, String targetUri) {
        this(localPort, targetPort, targetUri, null);
    }

    /** Creates a configuration with a display title. */
    public ForwardMsg(int localPort, int targetPort, String targetUri, String title) {
        this.localPort = localPort;
        this.targetPort = targetPort;
        this.targetUri = targetUri;
        this.title = title;
    }

    /** Formats as {@code "title:  local -> [host:port]"}. */
    @Override
    public String toString() {
        return new StringBuilder()
                .append(title).append(":  ")
                .append(localPort).append(" -> [")
                .append(targetUri).append(':').append(targetPort)
                .append(']')
                .toString();
    }
}


/**
 * Reads one HTTP/1.x request from a stream and rebuilds it as text for forwarding,
 * rewriting the Host header to {@code targetHost}. When a {@link Connect} is set, the
 * target is taken from the request's {@code issuedUri} query parameter instead.
 */
class HttpReader{

    private static final char   SPACE = ' ';
    private static final String BLANK = "";
    private static final String CONTENT_LENGTH = "content-length";
    private static final String NEW_LINE = "\r\n";
    private static final String HEAD_SEPARATOR = ": ";
    /** Query-parameter prefix carrying the per-request target ("host" or "host:port"). */
    private static final String ISSUED_PREFIX = "issuedUri=";
    /** Maximum number of chars requested from the reader per body read. */
    private static final int BODY_BUFFER_SIZE = 512;

    /** Request text as it will be forwarded (request line, headers, body). */
    private final StringBuilder source = new StringBuilder(512);
    /** Request headers keyed by lower-cased name. */
    private final Map<String,Object> headers = new HashMap<>(10);
    /** Request body. */
    private final StringBuilder body = new StringBuilder();
    /** Request input. */
    private final BufferedReader reader;
    /** Bytes consumed so far, counted in the platform default charset. */
    private int readLength = 0;
    /** HTTP method. */
    private String method;
    /** Request target (URI). */
    private String uri;
    /** HTTP version. */
    private String version;
    /** When set, the target is resolved from the request and its socket stored here. */
    private Connect connect;
    /** Value written into the Host header; null leaves the header unchanged. */
    private String targetHost;


    public HttpReader(InputStream inputStream) {
        this.reader = new BufferedReader(new InputStreamReader(inputStream));
    }

    public HttpReader(Socket socket) throws IOException {
        this(socket.getInputStream());
    }

    /**
     * Reads one line, appends it (with CRLF) to {@link #source} and counts its bytes.
     *
     * @return the line without terminator, or "" at end of stream
     */
    private String readLine() {
        try {
            String line = this.reader.readLine();
            if (line == null){
                return BLANK;
            }
            this.source.append(line).append(NEW_LINE);
            this.readLength += line.getBytes().length + NEW_LINE.length();
            return line;
        } catch (IOException e) {
            throw new RuntimeException("HTTP parse error", e);
        }
    }

    public static final Pattern ISSUED_DATA  = Pattern.compile("issuedUri=[^&]+");

    /** Parses the request line ("METHOD URI VERSION") and, if a Connect is set, resolves the target. */
    private void parseProtocol(){
        String line = this.readLine();
        int first = line.indexOf(SPACE);
        if (first > -1){
            this.method = line.substring(0, first);
            int second = line.indexOf(SPACE, first + 1);
            if (second > -1){
                this.uri = line.substring(first + 1, second);
                this.version = line.substring(second + 1);
            }
        }
        /* read the target address from the request itself */
        if (this.connect != null && this.uri != null){
            resolveIssuedTarget();
        }
    }

    /**
     * Extracts {@code issuedUri=host[:port]} from the URI, removes it, opens the upstream
     * socket into {@code connect.target} (port 80 when none is given) and rewrites the
     * buffered request line. A connect failure is printed and leaves the target null.
     *
     * @throws NumberFormatException if the port part is not a number
     */
    private void resolveIssuedTarget(){
        Matcher matcher = ISSUED_DATA.matcher(this.uri);
        if (!matcher.find()){
            return;
        }
        String sourceUri = matcher.group(0);
        this.uri = stripQueryParameter(this.uri, sourceUri);
        try {
            String address = URLDecoder.decode(sourceUri.substring(ISSUED_PREFIX.length()), "utf-8");
            String ip;
            int port;
            int index = address.indexOf(':');
            if (index > -1){
                ip = address.substring(0, index);
                port = Integer.parseInt(address.substring(index + 1));
                this.targetHost = ip + ':' + port;
            }else {
                ip = address;
                port = 80;
                this.targetHost = ip;
            }
            this.connect.target = new Socket(ip, port);
            // only the request line has been buffered so far; replace it with the cleaned one
            this.source.setLength(0);
            this.source.append(this.method).append(SPACE).append(this.uri)
                    .append(SPACE).append(this.version).append(NEW_LINE);
        } catch (IOException e) {
            // covers UnsupportedEncodingException and UnknownHostException as well
            e.printStackTrace();
        }
    }

    /**
     * Removes {@code param} and one adjacent separator from {@code uri}, e.g.
     * {@code "/p?param&a=1" -> "/p?a=1"}, {@code "/p?a=1&param" -> "/p?a=1"},
     * {@code "/p?param" -> "/p"}.
     */
    private static String stripQueryParameter(String uri, String param){
        int start = uri.indexOf(param);
        if (start < 0){
            return uri;
        }
        int end = start + param.length();
        if (end < uri.length() && uri.charAt(end) == '&'){
            end++;
        }else if (start > 0 && (uri.charAt(start - 1) == '&' || uri.charAt(start - 1) == '?')){
            start--;
        }
        String stripped = uri.substring(0, start) + uri.substring(end);
        return stripped.isEmpty() ? "/" : stripped;
    }

    /**
     * Reads header lines up to the blank line, storing them in {@link #headers} and
     * rewriting the Host value in {@link #source} to {@link #targetHost}.
     *
     * @return number of distinct headers stored
     */
    private int parseHeaders(){
        while (true){
            String line = this.readLine();
            if (line.isEmpty()){
                break;
            }
            int index = line.indexOf(HEAD_SEPARATOR);
            if (index < 0){
                continue;
            }
            String key   = line.substring(0, index).toLowerCase();
            String value = line.substring(index + HEAD_SEPARATOR.length());
            if (key.equals("host") && this.targetHost != null){
                // readLine just appended this line + CRLF, so the value sits at a fixed
                // offset from the end; searching the whole buffer could hit the request line
                int lineStart  = this.source.length() - line.length() - NEW_LINE.length();
                int valueStart = lineStart + index + HEAD_SEPARATOR.length();
                this.source.replace(valueStart, lineStart + line.length(), this.targetHost);
            }
            this.headers.put(key, value);
        }
        return this.headers.size();
    }


    /**
     * Parses request line and headers, then reads the body.
     * Content-Length is a byte count, so consumed chars are re-encoded (platform default
     * charset, the same one the reader decodes with) to count bytes. Only Content-Length
     * framing is handled; without that header no body is read.
     * NOTE(review): a surrogate pair split across two reads is under-counted.
     *
     * @return number of body bytes read
     * @throws NumberFormatException if Content-Length is not a number
     *
     * @author      KB
     * @createTime  2021/3/15 11:53
     */
    public int parseBody(){
        this.parseProtocol();
        this.parseHeaders();
        int contentSize = contentLength();
        if (contentSize <= 0){
            return 0;
        }
        char[] buffer = new char[Math.min(contentSize, BODY_BUFFER_SIZE)];
        int readSize = 0;
        try{
            // keep reading until Content-Length bytes arrived: a short read only means
            // the rest has not arrived yet (the old loop stopped there and truncated)
            while (readSize < contentSize){
                // each char is at least one byte, so this never reads past the body
                int max = Math.min(buffer.length, contentSize - readSize);
                int length = this.reader.read(buffer, 0, max);
                if (length < 0){
                    break;
                }
                body.append(buffer, 0, length);
                source.append(buffer, 0, length);
                int bytes = new String(buffer, 0, length).getBytes().length;
                readSize += bytes;
                this.readLength += bytes;
            }
        }catch(IOException e){
            e.printStackTrace();
        }
        return readSize;
    }

    /** @return the Content-Length header value, or 0 when absent */
    private int contentLength(){
        Object value = this.headers.get(CONTENT_LENGTH);
        return value == null ? 0 : Integer.parseInt(value.toString().trim());
    }

    public String getSource() {
        return source.toString();
    }

    /** @return declared body size in bytes, or 0 when no Content-Length header was sent */
    public int getBodyByteSize() {
        return contentLength();
    }

    public BufferedReader getReader() {
        return reader;
    }

    public int getReadLength() {
        return readLength;
    }

    public void setConnect(Connect connect) {
        this.connect = connect;
    }

    public void setTargetHost(String targetHost) {
        this.targetHost = targetHost;
    }
}



版权声明:本文为li18883754474原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。