依赖
<!-- Version properties shared by the dependency declarations below. -->
<properties>
<!-- Flink release applied to every org.apache.flink artifact below. -->
<flink.version>1.12.5</flink.version>
<!-- Suffix appended to the Flink artifact ids that carry one (e.g. flink-clients_2.12). -->
<scala.version>2.12</scala.version>
</properties>
<dependencies>
<!-- Flink Java API artifact (no Scala suffix in its artifact id). -->
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-java</artifactId>
<version>${flink.version}</version>
</dependency>
<!-- Flink streaming Java API artifact. -->
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-streaming-java_${scala.version}</artifactId>
<version>${flink.version}</version>
</dependency>
<!-- Flink clients artifact. NOTE(review): presumably required to launch the job locally - confirm. -->
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-clients_${scala.version}</artifactId>
<version>${flink.version}</version>
</dependency>
<!-- HanLP, "portable" build. -->
<dependency>
<groupId>com.hankcs</groupId>
<artifactId>hanlp</artifactId>
<version>portable-1.7.6</version>
</dependency>
</dependencies>
训练
negDir 和 posDir 分别是存放负面语料和正面语料的文件夹,文件夹中的语料均为 txt 文件。
public class NlpTrain {
public static void main(String[] args) {
final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
String negDir="E:\\nlp\\sample\\neg";
String posDir="E:\\nlp\\sample\\pos";
DataStream<TranObj> negData=env.readFile(FileRead.getDirTextInputFormat(negDir),negDir)
.assignTimestampsAndWatermarks(new WatermarkStrageWithTimestamp<>())
.filter(new FilterFunction<String>() {
@Override
public boolean filter(String s) throws Exception {
return s.length()>0;
}
})
.map(new MapFunction<String, TranObj>() {
@Override
public TranObj map(String s) throws Exception {
return new TranObj("负向", s);
}
});
DataStream<TranObj> posData=env.readFile(FileRead.getDirTextInputFormat(posDir),posDir)
.assignTimestampsAndWatermarks(new WatermarkStrageWithTimestamp<>())
.filter(new FilterFunction<String>() {
@Override
public boolean filter(String s) throws Exception {
return s.length()>0;
}
})
.map(new MapFunction<String, TranObj>() {
@Override
public TranObj map(String s) throws Exception {
return new TranObj("正向", s);
}
});;
negData.union(posData).keyBy(TranObj::getType)
.window(EventTimeSessionWindows.withGap(Time.minutes(1)))
.aggregate(new AggregateFunction<TranObj, List<TranObj>, Map<String, List<String>>>() {
@Override
public List<TranObj> createAccumulator() {
return new ArrayList<>();
}
@Override
public List<TranObj> add(TranObj tranObj, List<TranObj> tranObjs) {
tranObjs.add(tranObj);
return tranObjs;
}
@Override
public Map<String, List<String>> getResult(List<TranObj> tranObjs) {
List<String> list=new ArrayList<>();
String key=tranObjs.get(0).getType();
for (TranObj obj : tranObjs){
list.add(obj.getSample());
}
Map<String, List<String>> map=new HashMap<>();
map.put(key,list);
return map;
}
@Override
public List<TranObj> merge(List<TranObj> tranObjs, List<TranObj> acc1) {
tranObjs.addAll(acc1);
return tranObjs;
}
}).windowAll(EventTimeSessionWindows.withGap(Time.minutes(1)))
.reduce(new ReduceFunction<Map<String, List<String>>>() {
@Override
public Map<String, List<String>> reduce(Map<String, List<String>> map1, Map<String, List<String>> map2) throws Exception {
for (String key:map1.keySet()){
if (map2.containsKey(key)){
map1.get(key).addAll(map2.get(key));
}
}
for (String key:map2.keySet()){
if (!map1.containsKey(key)){
map1.put(key,map2.get(key));
}
}
return map1;
}
}).map(new MapFunction<Map<String, List<String>>, Map<String, String[]>>() {
@Override
public Map<String, String[]> map(Map<String, List<String>> stringListMap) throws Exception {
Map<String, String[]> tranMap=new HashMap<>();
for (String k:stringListMap.keySet()){
String[] arr=new String[stringListMap.get(k).size()];
tranMap.put(k,stringListMap.get(k).toArray(arr));
}
HanlpModel.trainNaiveBayesModel(tranMap, "E:\\log\\test.ser");
return tranMap;
}
});
try {
env.execute();
} catch (Exception e) {
e.printStackTrace();
}
}
}
HanlpModel.java
/**
 * Static helpers for training, persisting and loading a HanLP Naive Bayes model.
 */
public class HanlpModel {
    /**
     * Trains a new Naive Bayes classifier on {@code map} and serializes its model to {@code path}.
     *
     * @param map  training data passed unchanged to {@code IClassifier.train}
     * @param path file the trained model is written to
     * @throws IllegalStateException if the model could not be saved to {@code path}
     */
    public static void trainNaiveBayesModel(Map<String, String[]> map, String path) {
        IClassifier classifier = new NaiveBayesClassifier();
        classifier.train(map);
        NaiveBayesModel model = (NaiveBayesModel) classifier.getModel();
        // saveObjectTo signals failure through its boolean result rather than an exception;
        // ignoring it would leave callers believing a model file exists when it does not.
        if (!IOUtil.saveObjectTo(model, path)) {
            throw new IllegalStateException("Failed to save Naive Bayes model to " + path);
        }
    }

    /**
     * Reads a previously saved model from {@code path}.
     *
     * <p>NOTE(review): HanLP's {@code IOUtil.readObjectFrom} is understood to return
     * {@code null} when the file cannot be read - confirm; callers should be prepared for that.
     *
     * @param path file to read the serialized model from
     * @return the object read from {@code path}, cast to {@code NaiveBayesModel}
     */
    public static NaiveBayesModel getNaiveBayesModel(String path) {
        return (NaiveBayesModel) IOUtil.readObjectFrom(path);
    }
}
WatermarkStrageWithTimestamp.java
/**
 * Watermark strategy that stamps every element with the wall-clock time at which the timestamp
 * is extracted and generates watermarks with a bounded out-of-orderness of ten seconds.
 *
 * @param <T> element type of the stream the strategy is applied to
 */
public class WatermarkStrageWithTimestamp<T> implements WatermarkStrategy<T> {
    /**
     * Returns an assigner that ignores both the element and its previous timestamp and
     * answers with {@link System#currentTimeMillis()}.
     */
    @Override
    public TimestampAssigner<T> createTimestampAssigner(TimestampAssignerSupplier.Context context) {
        return (element, recordTimestamp) -> System.currentTimeMillis();
    }

    /** Returns a bounded-out-of-orderness generator configured with a ten second bound. */
    @Override
    public WatermarkGenerator<T> createWatermarkGenerator(WatermarkGeneratorSupplier.Context context) {
        final Duration maxOutOfOrderness = Duration.ofSeconds(10);
        return new BoundedOutOfOrdernessWatermarks<>(maxOutOfOrderness);
    }
}
FileRead.java
/**
 * Factory for the text input format used to read a corpus directory.
 */
public class FileRead {
    /**
     * Creates a {@code TextInputFormat} rooted at {@code dir}, configured with
     * {@code recursive.file.enumeration=true} and the UTF-8 charset.
     *
     * @param dir directory (or file) path the format reads from
     * @return the configured input format
     */
    public static TextInputFormat getDirTextInputFormat(String dir) {
        TextInputFormat textInputFormat = new TextInputFormat(new Path(dir));

        // The original also called supportsMultiPaths() here and discarded the result; that
        // call only queries a capability and configures nothing, so it has been dropped.
        Configuration configuration = new Configuration();
        configuration.setBoolean("recursive.file.enumeration", true);
        textInputFormat.configure(configuration);

        textInputFormat.setCharsetName("UTF-8");
        return textInputFormat;
    }
}
测试
/**
 * Flink job that classifies every non-empty record of {@code E:\log\test.txt} with the model
 * stored at {@code E:\log\test.ser} and prints {@code <category>:<content>} for each record.
 */
public class Test {
    /**
     * Builds the classification pipeline and runs it.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.readTextFile("E:\\log\\test.txt").filter(new FilterFunction<String>() {
            @Override
            public boolean filter(String input) throws Exception {
                return !input.isEmpty();
            }
        }).map(new MapFunction<String, DataObj>() {
            @Override
            public DataObj map(String input) throws Exception {
                return new DataObj(input);
            }
        }).map(new MapFunction<DataObj, String>() {
            // Created on first use and then reused, so the model file is read once per
            // function instance instead of once per record. Transient so the classifier is
            // never part of the serialized function.
            private transient IClassifier classifier;

            @Override
            public String map(DataObj dataObj) throws Exception {
                if (classifier == null) {
                    classifier = new NaiveBayesClassifier(HanlpModel.getNaiveBayesModel("E:\\log\\test.ser"));
                }
                return classifier.classify(dataObj.getContent()) + ":" + dataObj.getContent();
            }
        }).print();
        try {
            env.execute();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
版权声明:本文为GrassEva原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。