Flink 使用 HanLP 进行情感分析

  • Post author:
  • Post category:其他


依赖

<!-- Shared version numbers referenced by the dependencies below. -->
<properties>
    <flink.version>1.12.5</flink.version>
    <!-- Scala *binary* version: used only as the suffix of Scala-dependent Flink artifactIds
         (e.g. flink-streaming-java_2.12), not as a full Scala version. -->
    <scala.version>2.12</scala.version>
</properties>

<dependencies>
    <!-- Flink Java API. -->
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-java</artifactId>
        <version>${flink.version}</version>
    </dependency>
    <!-- Flink DataStream API (StreamExecutionEnvironment, windows, watermarks). -->
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-streaming-java_${scala.version}</artifactId>
        <version>${flink.version}</version>
    </dependency>
    <!-- Flink client/executor; needed so env.execute() can run the job from a local main method. -->
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-clients_${scala.version}</artifactId>
        <version>${flink.version}</version>
    </dependency>

    <!--hanlp-->
    <!-- HanLP: provides the naive Bayes classifier and model I/O used by HanlpModel and Test. -->
    <dependency>
        <groupId>com.hankcs</groupId>
        <artifactId>hanlp</artifactId>
        <version>portable-1.7.6</version>
    </dependency>
</dependencies>

训练

negDir 和 posDir 分别是存放负面语料和正面语料的文件夹（会递归读取子目录），其中的语料为 UTF-8 编码的 txt 文件，每个非空行作为一条训练样本；训练好的模型保存到 E:\log\test.ser。

/**
 * Flink job that reads a labelled sentiment corpus and trains a HanLP naive Bayes model.
 *
 * <p>Every non-empty line of every file under the negative / positive corpus directory becomes
 * one training sample labelled "负向" / "正向". All samples are gathered into a single
 * label -> samples map, which is handed to {@link HanlpModel#trainNaiveBayesModel}.
 *
 * <p>Optional program arguments: {@code <negDir> <posDir> <modelPath>}. Any argument that is
 * missing or blank falls back to the original hard-coded path.
 */
public class NlpTrain {
    private static final String NEGATIVE_LABEL = "负向";
    private static final String POSITIVE_LABEL = "正向";

    private static final String DEFAULT_NEG_DIR = "E:\\nlp\\sample\\neg";
    private static final String DEFAULT_POS_DIR = "E:\\nlp\\sample\\pos";
    private static final String DEFAULT_MODEL_PATH = "E:\\log\\test.ser";

    public static void main(String[] args) {
        final String negDir = argOrDefault(args, 0, DEFAULT_NEG_DIR);
        final String posDir = argOrDefault(args, 1, DEFAULT_POS_DIR);
        final String modelPath = argOrDefault(args, 2, DEFAULT_MODEL_PATH);

        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        DataStream<TranObj> negData = readLabeledSamples(env, negDir, NEGATIVE_LABEL);
        DataStream<TranObj> posData = readLabeledSamples(env, posDir, POSITIVE_LABEL);

        // Both labels are collected in ONE session window so the model is always trained on the
        // complete corpus. Windowing each label separately and then re-windowing the per-label
        // results only merges them when the per-label windows end within one gap of each other;
        // otherwise training would run on a single-label corpus and overwrite the model file.
        negData.union(posData)
                .windowAll(EventTimeSessionWindows.withGap(Time.minutes(1)))
                .aggregate(new CorpusCollector())
                .map(new NaiveBayesTrainer(modelPath));

        try {
            env.execute();
        } catch (Exception e) {
            // Fail loudly instead of printing the trace and exiting as if training succeeded.
            throw new IllegalStateException("Sentiment model training job failed", e);
        }
    }

    /** Returns {@code args[index]} when present and non-blank, otherwise {@code fallback}. */
    private static String argOrDefault(String[] args, int index, String fallback) {
        if (args == null || args.length <= index || args[index] == null || args[index].trim().isEmpty()) {
            return fallback;
        }
        return args[index];
    }

    /** Reads every non-empty line under {@code dir} as one sample tagged with {@code label}. */
    private static DataStream<TranObj> readLabeledSamples(
            StreamExecutionEnvironment env, String dir, final String label) {
        return env.readFile(FileRead.getDirTextInputFormat(dir), dir)
                .assignTimestampsAndWatermarks(new WatermarkStrageWithTimestamp<>())
                .filter(new FilterFunction<String>() {
                    @Override
                    public boolean filter(String s) throws Exception {
                        return s.length() > 0;
                    }
                })
                .map(new MapFunction<String, TranObj>() {
                    @Override
                    public TranObj map(String s) throws Exception {
                        return new TranObj(label, s);
                    }
                });
    }

    /**
     * Groups samples by label. The result maps each label to the array of its sample texts,
     * which is the shape HanLP's trainer accepts.
     */
    private static final class CorpusCollector
            implements AggregateFunction<TranObj, Map<String, List<String>>, Map<String, String[]>> {
        @Override
        public Map<String, List<String>> createAccumulator() {
            return new HashMap<>();
        }

        @Override
        public Map<String, List<String>> add(TranObj sample, Map<String, List<String>> acc) {
            acc.computeIfAbsent(sample.getType(), k -> new ArrayList<>()).add(sample.getSample());
            return acc;
        }

        @Override
        public Map<String, String[]> getResult(Map<String, List<String>> acc) {
            Map<String, String[]> corpus = new HashMap<>();
            for (Map.Entry<String, List<String>> entry : acc.entrySet()) {
                corpus.put(entry.getKey(), entry.getValue().toArray(new String[0]));
            }
            return corpus;
        }

        @Override
        public Map<String, List<String>> merge(Map<String, List<String>> a, Map<String, List<String>> b) {
            // Needed because session windows merge their accumulators.
            for (Map.Entry<String, List<String>> entry : b.entrySet()) {
                a.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).addAll(entry.getValue());
            }
            return a;
        }
    }

    /** Trains the naive Bayes model on the collected corpus and writes it to {@code modelPath}. */
    private static final class NaiveBayesTrainer
            implements MapFunction<Map<String, String[]>, Map<String, String[]>> {
        private final String modelPath;

        NaiveBayesTrainer(String modelPath) {
            this.modelPath = modelPath;
        }

        @Override
        public Map<String, String[]> map(Map<String, String[]> corpus) throws Exception {
            HanlpModel.trainNaiveBayesModel(corpus, modelPath);
            return corpus;
        }
    }
}

HanlpModel.java

/** Helpers for training, saving and loading HanLP naive Bayes sentiment models. */
public class HanlpModel {
    /**
     * Trains a naive Bayes classifier on {@code map} and serializes the resulting model to
     * {@code path}.
     *
     * @param map  training corpus: category label -> sample texts of that category
     * @param path file the trained model is written to
     * @throws IllegalStateException if the model could not be written to {@code path}
     */
    public static void trainNaiveBayesModel(Map<String, String[]> map, String path) {
        IClassifier classifier = new NaiveBayesClassifier();
        classifier.train(map);
        NaiveBayesModel model = (NaiveBayesModel) classifier.getModel();
        // IOUtil signals a failed write through its boolean result rather than an exception.
        // Without this check the failure would only surface later, when the model is loaded.
        if (!IOUtil.saveObjectTo(model, path)) {
            throw new IllegalStateException("Failed to save naive Bayes model to " + path);
        }
    }

    /**
     * Loads a model previously written by {@link #trainNaiveBayesModel}.
     *
     * <p>NOTE: this uses Java native deserialization, so only load model files from a trusted
     * location.
     *
     * @param path file containing the serialized model
     * @return the deserialized model, never {@code null}
     * @throws IllegalStateException if the file cannot be read or does not hold a NaiveBayesModel
     */
    public static NaiveBayesModel getNaiveBayesModel(String path){
        // IOUtil returns null (after logging) when reading fails.
        Object model = IOUtil.readObjectFrom(path);
        if (!(model instanceof NaiveBayesModel)) {
            throw new IllegalStateException("No naive Bayes model could be read from " + path);
        }
        return (NaiveBayesModel) model;
    }

}

WatermarkStrageWithTimestamp.java

/**
 * Watermark strategy that stamps every element with the wall-clock time at which it passes the
 * assigner. Watermarks come from Flink's bounded-out-of-orderness generator with a 10-second bound.
 *
 * @param <T> element type; the element itself is never inspected
 */
public class WatermarkStrageWithTimestamp<T> implements WatermarkStrategy<T> {
    private static final Duration OUT_OF_ORDERNESS_BOUND = Duration.ofSeconds(10);

    @Override
    public WatermarkGenerator<T> createWatermarkGenerator(WatermarkGeneratorSupplier.Context context) {
        WatermarkGenerator<T> generator = new BoundedOutOfOrdernessWatermarks<>(OUT_OF_ORDERNESS_BOUND);
        return generator;
    }

    @Override
    public TimestampAssigner<T> createTimestampAssigner(TimestampAssignerSupplier.Context context) {
        // Both the element and any timestamp already attached to it are deliberately ignored.
        return (element, previousTimestamp) -> System.currentTimeMillis();
    }
}

FileRead.java

/** Factory for the text input formats used to read corpus directories. */
public class FileRead {
    /**
     * Creates a UTF-8 {@link TextInputFormat} over {@code dir} that also descends into nested
     * sub-directories. Each record produced by the format is one line of text.
     *
     * @param dir directory (or single file) to read
     * @return the configured input format
     */
    public static TextInputFormat getDirTextInputFormat(String dir){
        TextInputFormat textInputFormat = new TextInputFormat(new Path(dir));
        // Explicitly enable recursive enumeration. This is what configure() with
        // "recursive.file.enumeration" = true achieved, without a throwaway Configuration.
        textInputFormat.setNestedFileEnumeration(true);
        textInputFormat.setCharsetName("UTF-8");
        return textInputFormat;
    }
}

测试：加载训练好的模型，对 E:\log\test.txt 中的每个非空行进行分类，并输出“类别:文本”。

/**
 * Flink job that classifies each non-empty line of a text file with a previously trained HanLP
 * naive Bayes sentiment model and prints {@code <category>:<text>}.
 *
 * <p>Optional program arguments: {@code <inputFile> <modelPath>}. Missing arguments fall back to
 * the original hard-coded paths.
 */
public class Test {
    private static final String DEFAULT_INPUT_PATH = "E:\\log\\test.txt";
    private static final String DEFAULT_MODEL_PATH = "E:\\log\\test.ser";

    public static void main(String[] args) {
        final String inputPath = args != null && args.length > 0 ? args[0] : DEFAULT_INPUT_PATH;
        final String modelPath = args != null && args.length > 1 ? args[1] : DEFAULT_MODEL_PATH;

        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.readTextFile(inputPath)
                .filter(new FilterFunction<String>() {
                    @Override
                    public boolean filter(String input) throws Exception {
                        return input.length() > 0;
                    }
                })
                .map(new MapFunction<String, DataObj>() {
                    @Override
                    public DataObj map(String input) throws Exception {
                        return new DataObj(input);
                    }
                })
                .map(new MapFunction<DataObj, String>() {
                    // Built lazily, once per parallel instance. Previously the model file was
                    // read and deserialized again for every single record. The field is transient
                    // so the classifier is not shipped with the serialized function.
                    private transient IClassifier classifier;

                    @Override
                    public String map(DataObj dataObj) throws Exception {
                        if (classifier == null) {
                            classifier = new NaiveBayesClassifier(HanlpModel.getNaiveBayesModel(modelPath));
                        }
                        return classifier.classify(dataObj.getContent()) + ":" + dataObj.getContent();
                    }
                })
                .print();
        try {
            env.execute();
        } catch (Exception e) {
            // Fail loudly instead of printing the trace and exiting as if the job succeeded.
            throw new IllegalStateException("Sentiment classification job failed", e);
        }
    }
}



版权声明:本文为GrassEva原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。