프로젝트/알약인식 앱💊

[Pill Project] TFlite, object detection 모델 안드로이드에 이식하기

핸디_ 2021. 9. 4. 01:07

.pb 모델형태는 안드로이드 이식을 더 이상 지원하지 않는 것 같았기 때문에 모델을 tflite형태로 변환하였다.

개인적 의견으로는, tflite모델로 변환했다면 사실 절반 이상은 완료했다고 봐도 무방하다.^^

 

과정을 간단하게만 기록해보자면

 

📌 안드로이드의 assets 폴더에 만든 모델.tflite를 복사해서 넣어준다.

mylabel.txt라는 텍스트파일도 assets에 넣어주는데, 이 텍스트파일에는 모델에서 구분하려는 각 클래스명을 한줄에 하나씩 적어주면 된다. 

 

📌MainActivity 수정

우리의 목적은

0) 모델 불러오기
1) 직접 사진을 찍거나 앨범에서 사진을 불러와서
2) 불러온 사진을 모델에 넘겨줌
3) 입력으로 들어온 사진을 모델에 넣어서 결과를 예측
4) 예측한 결과를 바탕으로 bounding box를 만들어 원본사진에 pill과 text표시

각 단계에서 중요한 코드들은 다음과 같다.

 

0) 모델 불러오기

 

try { //-> onCreate()안에다가 작성
            tflite = loadmodelfile(this@MainActivity)?.let { Interpreter(it) }
        } catch (e: java.lang.Exception) {
            e.printStackTrace()
        }
        
   private fun loadmodelfile(activity: Activity): MappedByteBuffer? { 
   //model load를 위한 file loading
        val fileDescriptor: AssetFileDescriptor =
            activity.getAssets().openFd("converted_original_model.tflite")

        val inputStream = FileInputStream(fileDescriptor.getFileDescriptor())
        val fileChannel: FileChannel = inputStream.getChannel()
        val startoffset: Long = fileDescriptor.getStartOffset()
        val declaredLength: Long = fileDescriptor.getDeclaredLength()
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startoffset, declaredLength)
    }

 

1) 직접 사진을 찍거나 앨범에서 사진을 불러오기 2) 모델에 넘겨주기

 captureImageFab.setOnClickListener(this)
 ButtonFromAlbum.setOnClickListener(this)
 
 --------------------------------------------
 
 override fun onClick(v: View?) {
        when (v?.id) {
            R.id.captureImageFab -> {
                try {
                    dispatchTakePictureIntent()
                } catch (e: ActivityNotFoundException) {
                    Log.e(TAG, e.message.toString())
                }
            }
            R.id.ButtonFromAlbum -> {
                try {
                    openGallery()
                } catch (e: ActivityNotFoundException) {
                    Log.e(TAG, e.message.toString())
                }
            }
          }
          
  --------------------------------------------
    private fun dispatchTakePictureIntent() {
        Intent(MediaStore.ACTION_IMAGE_CAPTURE).also { takePictureIntent ->
            // Ensure that there's a camera activity to handle the intent
            takePictureIntent.resolveActivity(packageManager)?.also {
                // Create the File where the photo should go
                val photoFile: File? = try {
                    createImageFile()
                } catch (e: IOException) {
                    Log.e(TAG, e.message.toString())
                    null
                }
                // Continue only if the File was successfully created
                photoFile?.also {
                    val photoURI: Uri = FileProvider.getUriForFile(
                        this,
                        "org.tensorflow.codelabs.objectdetection.fileprovider",
                        it
                    )
                    takePictureIntent.putExtra(MediaStore.EXTRA_OUTPUT, photoURI)
                    startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE)
                }
            }
        }
    }
    
        private fun openGallery() {
        val intent = Intent(Intent.ACTION_PICK)
        intent.type = MediaStore.Images.Media.CONTENT_TYPE
        startActivityForResult(intent, REQ_GALLERY)
    }

 

3) 모델에 넣어서 결과 예측하기

 if (resultCode == Activity.RESULT_OK) {
            when (requestCode) {
                REQUEST_IMAGE_CAPTURE -> {
                    setViewAndDetect(getCapturedImage())
                }
                REQ_GALLERY -> {
                    var dataUri = data?.data
                    try {
                        var bitmap: Bitmap =
                            MediaStore.Images.Media.getBitmap(this.contentResolver, dataUri)
                        setViewAndDetect(bitmap)//TODO 갤러리에서 이미지 띄우는 거
                    } catch (e: Exception) {
                        Toast.makeText(this, "$e", Toast.LENGTH_LONG).show()

                    }

                }
            }
        }
        
        //setViewAndDetect : 비트맵 형태의 이미지를 파라마터로 받아서
        //이미지를 뷰에 띄우고 object detection을 부름
    private fun getCapturedImage(): Bitmap {
    //카메라에서 찍은 사진을 decoding하고 자름
    
        // Get the dimensions of the View
        val targetW: Int = inputImageView.width
        val targetH: Int = inputImageView.height

        val bmOptions = BitmapFactory.Options().apply {
            // Get the dimensions of the bitmap
            inJustDecodeBounds = true

            BitmapFactory.decodeFile(currentPhotoPath, this)

            val photoW: Int = outWidth
            val photoH: Int = outHeight

            // Determine how much to scale down the image
            val scaleFactor: Int = max(1, min(photoW / targetW, photoH / targetH))

            // Decode the image file into a Bitmap sized to fill the View
            inJustDecodeBounds = false
            inSampleSize = scaleFactor
            inMutable = true
        }
        val exifInterface = ExifInterface(currentPhotoPath)
        val orientation = exifInterface.getAttributeInt(
            ExifInterface.TAG_ORIENTATION,
            ExifInterface.ORIENTATION_UNDEFINED
        )

        val bitmap = BitmapFactory.decodeFile(currentPhotoPath, bmOptions)
        return when (orientation) {
            ExifInterface.ORIENTATION_ROTATE_90 -> {
                rotateImage(bitmap, 90f)
            }
            ExifInterface.ORIENTATION_ROTATE_180 -> {
                rotateImage(bitmap, 180f)
            }
            ExifInterface.ORIENTATION_ROTATE_270 -> {
                rotateImage(bitmap, 270f)
            }
            else -> {
                bitmap
            }
        }
    }
 private fun openGallery() {
 //갤러리에서 사진 가져오기
        val intent = Intent(Intent.ACTION_PICK)
        intent.type = MediaStore.Images.Media.CONTENT_TYPE
        startActivityForResult(intent, REQ_GALLERY)
    }

 

 4) 결과를 바탕으로 bounding box 표시

private fun drawDetectionResult(
        bitmap: Bitmap,
        detectionResults: List<DetectionResult>
    ): List<CropResult> {
        val outputBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)

        val deviceWidth = resources.displayMetrics.widthPixels
        val deviceHeight = resources.displayMetrics.heightPixels
        val density = resources.displayMetrics.densityDpi

        val bitmapWidth = outputBitmap.width
        val bitmapHeight = outputBitmap.height
        Log.e(TAG, "width: ${bitmapWidth}")
        Log.e(TAG, "height: ${bitmapHeight}")

        val scaleHeight = deviceWidth * bitmapHeight / bitmapWidth
        //val resizeBp = Bitmap.createScaledBitmap(outputBitmap,deviceWidth, scaleHeight, true)

        //val resultBitmap :List<CropResult> = ArrayList<CropResult>()
        val resultBitmap = mutableListOf<CropResult>()

        val canvas = Canvas(outputBitmap)
        val pen = Paint()
        pen.textAlign = Paint.Align.LEFT

        detectionResults.forEach {
            // draw bounding box
            pen.color = Color.RED
            pen.strokeWidth = 8F
            pen.style = Paint.Style.STROKE
            val box = it.boundingBox
            canvas.drawRect(box, pen)

            val boxWidth = box.width()
            val boxHeight = box.height()
            //사진 자르기
            Log.e(TAG, "top: ${box.top}")
            Log.e(TAG, "left: ${box.left}")
            Log.e(TAG, "boxwidth: ${boxWidth}")
            Log.e(TAG, "boxheight: ${boxHeight}")
            val cropBitmap = Bitmap.createBitmap(
                outputBitmap,
                box.left.toInt(),
                box.top.toInt(),
                boxWidth.toInt(),
                boxHeight.toInt()
            )
            println("hello")
            println("text:" + it.text)
            if (it.text == "pill") {
                println("goto pill")
                //pillBitmap = Bitmap.createBitmap(cropBitmap)
                var bp = Bitmap.createBitmap(cropBitmap)
                var labeltext = "pill"
                resultBitmap.add(CropResult(bp, labeltext))
            } else if (it.text == "text") {
                println("goto text")
                var bp = Bitmap.createBitmap(cropBitmap)
                var labeltext = "text"
                resultBitmap.add(CropResult(bp, labeltext))

            }


        }
        return resultBitmap

    }

 

tensorflow lite model 공식문서 가이드와 함께 참고하여 코드를 작성하면 다음과 같이 detect결과가 잘 나오는 걸 볼 수 있다.