基于TensorFlow的交通标志识别Android APP的设计

  • Post author:
  • Post category:其他


一、创建数据集,作者整理了57种类别的交通标志图片

二、训练模型,作者使用TensorFlow深度学习框架训练所需的模型文件

    def run(self):
        # 1. 加载数据集
        train_dataset, validate_dataset, class_names = self.m_data_load(self.train_dir, 224, 224, 16)
        self.ui.train_dataset = train_dataset
        self.ui.validate_dataset = validate_dataset
        # 2. 加载模型
        model = self.model_load(class_num=len(class_names))
        self.ui.model = model

        self.signal.emit(str(len(class_names)), class_names)

    # 模型加载
    def model_load(IMG_SHAPE=(224, 224, 3), class_num=214):
        base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
        base_model.trainable = False

        model = tf.keras.models.Sequential([
            tf.keras.layers.experimental.preprocessing.Rescaling(1. / 127.5, offset=-1, input_shape=(224, 224, 3)),
            base_model,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(class_num, activation='softmax')
        ])

        # 输出模型信息
        model.summary()
        model.compile(optimizer='adam', loss='categorical_crossentropy',
                      metrics=['accuracy'])
        return model

三、将训练好的模型文件导入Android工程的assets文件夹下。

四、编写Android代码

public class MainActivity extends Activity {

    // 类别的数量
    private int number = 57;
    // 类别名称
    private String class_names[] = {
            "限速15", "限速30", "限速40", "限速50", "限速60", "限速70", "限速80", "禁止直行和左转", "禁止直行和右转", "禁止直行", "禁止左转",
            "禁止左转和右转", "禁止右转", "禁止超车", "禁止掉头", "禁止机动车通行", "禁止鸣喇叭", "解除限速40", "解除限速50",
            "直行和右转", "直行", "左转", "左转和右转", "右转", "靠左侧道路行驶", "靠右侧道路行驶", "环岛行驶", "机动车行驶", "鸣喇叭", "非机动车行驶 ", "掉头", "注意避让",
            "注意红绿灯", "注意危险", "注意行人", "注意非机动车", "注意儿童", "注意急右转弯", "注意急左转弯", "注意下坡", "注意上坡",
            "注意慢行", "T型交叉", "T型交叉", "村庄", "反向弯路", "无人看守铁路道口", "施工", "连续弯路", "有人看守铁路道口", "事故易发路段",
            "停车让行", "禁止通行", "禁止停车", "禁止驶入", "减速让行", "停车检查"
        };
    // 输入
    private int[] input = {1, 224, 224, 3};
    // 输出
    private float[][] output = new float[1][number];

    private Interpreter interpreter;
    private Bitmap bitmap;
    private ImageView iv_vegetable;


    private String[] neededPermissions = new String[]{
            Manifest.permission.READ_PHONE_STATE
    };
    private TextView tv_text;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) {
            Window window = this.getWindow();
            window.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
            window.getDecorView().setSystemUiVisibility(View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN
                    | View.SYSTEM_UI_FLAG_LAYOUT_STABLE);
            window.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
            window.setStatusBarColor(Color.GRAY);

        }
        setContentView(R.layout.activity_main);

        /*
         * 在选择图片的时候,在android 7.0及以上通过FileProvider获取Uri,不需要文件权限
         */
        if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) {
            List<String> permissionList = new ArrayList<>(Arrays.asList(neededPermissions));
            permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
            neededPermissions = permissionList.toArray(new String[0]);
        }

        initView();

        TFLiteLoader loader = TFLiteLoader.newInstance(this);
        interpreter = loader.get();

        showToast("模型加载成功!");

        bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.orange);
    }

五、实现效果

基于TensorFlow的交通标志识别

六、完整源码下载

链接:https://pan.baidu.com/s/1vhtkevbQdbt3nuB6YOTzZA?pwd=99gp

提取码:99gp



版权声明:本文为lilihexiaoxiangege原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。