Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,49 @@ allprojects {
}
```

project build.gradle
```groovy
version catalogs
```
[versions]
moko-tensorflow = "<latest-version>"

[libraries]
moko-tensorflow = { module = "dev.icerock.moko:tensorflow", version.ref = "moko-tensorflow" }
```

If using default KMP plugin, type in your project.gradle.kts
```kotlin
kotlin {
cocoapods {
// other cocoapods configurations here
pod("TensorFlowLiteObjC") {
moduleName = "TFLTensorFlowLite"
}
// Or in non-exported module
pod(name="TensorFlowLiteObjC", linkOnly = true, moduleName = "TFLTensorFlowLite")
}
sourceSets {
val commonMain by getting {
dependencies {
api(libs.moko.tensorflow)
}
}
}
}
```
If using default moko gradle plugin, type in your project.gradle.kts
```kotlin
dependencies {
commonMainApi("dev.icerock.moko:tensorflow:0.3.0")
commonMainApi(libs.moko.tensorflow)
}

cocoaPods {
podsProject = file("../ios-app/Pods/Pods.xcodeproj") // here should be path to Pods xcode project
// here should be path to Pods xcode project
podsProject = file("../ios-app/Pods/Pods.xcodeproj")

pod("TensorFlowLiteObjC", module = "TFLTensorFlowLite", onlyLink = true)
}

```
Also add fraemwork location resolver into your project.gradle.kts
```kotlin
kotlin.targets
.filterIsInstance<org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget>()
.flatMap { it.binaries }
Expand All @@ -65,7 +96,7 @@ kotlin.targets

Podfile
```ruby
pod 'mokoTensorflow', :git => 'https://github.com/icerockdev/moko-tensorflow.git', :tag => 'release/0.3.0'
pod 'mokoTensorflow', :git => 'https://github.com/icerockdev/moko-tensorflow.git', :tag => 'release/<latest-version>'
```

## Usage
Expand Down Expand Up @@ -137,6 +168,11 @@ class ViewController: UIViewController {
}
```

## Pitfalls
1. Only Float32 is supported, but you can easily add your own type convertors
2. When using ObjCInterpreter you are required to convert any input data into NSData.
3. Because of IOS array allocation, only supported models with structure [N,X,Y,...,Z], where N is batch size

## Samples
Please see more examples in the [sample directory](sample).

Expand Down
2 changes: 1 addition & 1 deletion mokoTensorflow.podspec
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Pod::Spec.new do |spec|
spec.source_files = "tensorflow/src/iosMain/swift/**/*.{h,m,swift}"
spec.resources = "tensorflow/src/iosMain/bundle/**/*"

spec.dependency 'TensorFlowLiteObjC', '0.0.1-nightly.20230212'
spec.dependency 'TensorFlowLiteObjC', '2.12.0'

spec.ios.deployment_target = '11.0'
spec.swift_version = '5.0'
Expand Down
16 changes: 11 additions & 5 deletions sample/android-app/src/main/java/com/icerockdev/MainActivity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import androidx.lifecycle.lifecycleScope
import com.divyanshu.draw.widget.DrawView
import com.icerockdev.library.ResHolder
import com.icerockdev.library.TFDigitClassifier
import dev.icerock.moko.tensorflow.Interpreter
import dev.icerock.moko.tensorflow.InterpreterOptions
import dev.icerock.moko.tensorflow.JVMInterpreter
import dev.icerock.moko.tensorflow.NativeInput
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.concurrent.atomic.AtomicBoolean
Expand All @@ -28,7 +29,7 @@ class MainActivity : AppCompatActivity() {
private lateinit var clearButton: Button
private lateinit var predictedTextView: TextView

private lateinit var interpreter: Interpreter
private lateinit var interpreter: JVMInterpreter
private lateinit var digitClassifier: TFDigitClassifier

private val isInterpreterInited = AtomicBoolean(false)
Expand Down Expand Up @@ -60,7 +61,7 @@ class MainActivity : AppCompatActivity() {
true
}

interpreter = Interpreter(ResHolder.getDigitsClassifierModel(), InterpreterOptions(2, useNNAPI = true), this)
interpreter = JVMInterpreter(ResHolder.getDigitsClassifierModel(), InterpreterOptions(2, useNNAPI = true), this)
digitClassifier = TFDigitClassifier(interpreter, this.lifecycleScope)

digitClassifier.initialize()
Expand All @@ -82,8 +83,13 @@ class MainActivity : AppCompatActivity() {
digitClassifier.inputImageHeight,
true
)

digitClassifier.classifyAsync(convertBitmapToByteBuffer(bitmapToClassify)) {
val byteBuffer = convertBitmapToByteBuffer(bitmapToClassify)
// digitClassifier.classifyAsync(byteBuffer) {
// runOnUiThread {
// predictedTextView.text = it
// }
// }
digitClassifier.classifyNativeAsync(NativeInput(byteBuffer)) {
runOnUiThread {
predictedTextView.text = it
}
Expand Down
24 changes: 12 additions & 12 deletions sample/ios-app/Podfile.lock
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
PODS:
- mokoTensorflow (0.4.0):
- TensorFlowLiteObjC (= 0.0.1-nightly.20230212)
- mokoTensorflow (0.3.0):
- TensorFlowLiteObjC (= 2.12.0)
- MultiPlatformLibrary (0.1.0)
- Sketch (3.0)
- TensorFlowLiteC (0.0.1-nightly.20230212):
- TensorFlowLiteC/Core (= 0.0.1-nightly.20230212)
- TensorFlowLiteC/Core (0.0.1-nightly.20230212)
- TensorFlowLiteObjC (0.0.1-nightly.20230212):
- TensorFlowLiteObjC/Core (= 0.0.1-nightly.20230212)
- TensorFlowLiteObjC/Core (0.0.1-nightly.20230212):
- TensorFlowLiteC (= 0.0.1-nightly.20230212)
- TensorFlowLiteC (2.12.0):
- TensorFlowLiteC/Core (= 2.12.0)
- TensorFlowLiteC/Core (2.12.0)
- TensorFlowLiteObjC (2.12.0):
- TensorFlowLiteObjC/Core (= 2.12.0)
- TensorFlowLiteObjC/Core (2.12.0):
- TensorFlowLiteC (= 2.12.0)

DEPENDENCIES:
- mokoTensorflow (from `../..`)
Expand All @@ -29,11 +29,11 @@ EXTERNAL SOURCES:
:path: "../mpp-library"

SPEC CHECKSUMS:
mokoTensorflow: f31dd35d9c68c098aa842b2885ded86e24b91a83
mokoTensorflow: 3b3781b48d0b8822a9a5811f6555be26156f7b4b
MultiPlatformLibrary: 91d3837ea2c0943e0713f98671a36913470ef412
Sketch: 49a4b71f7bc77316ed5f75ee79dedaa2b844d5e7
TensorFlowLiteC: 131cd06718a81ace70d56f10b1404157ce40d7fc
TensorFlowLiteObjC: 2e5cf40e720254b0905e09a7f638e7a2cf935727
TensorFlowLiteC: 20785a69299185a379ba9852b6625f00afd7984a
TensorFlowLiteObjC: 9a46a29a76661c513172cfffd3bf712b11ef25c3

PODFILE CHECKSUM: 5d66f0fb585809b01037efa8b5383d224d5dca98

Expand Down
2 changes: 1 addition & 1 deletion sample/ios-app/tensorflow-test/ViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ViewController: UIViewController, SketchViewDelegate {
)
let modelFileRes: ResourcesFileResource = ResHolder().getDigitsClassifierModel()

interpreter = Interpreter(
interpreter = ObjCInterpreter(
fileResource: modelFileRes,
options: options
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package com.icerockdev.library

import dev.icerock.moko.tensorflow.Interpreter
import dev.icerock.moko.tensorflow.NativeInput
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
Expand All @@ -26,12 +27,24 @@ class TFDigitClassifier(
inputImageWidth = inputShape[1]
inputImageHeight = inputShape[2]
modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
interpreter.allocateTensors()
}

fun classifyAsync(inputData: Any, onResult: (String) -> Unit) {
scope.launch(Dispatchers.Default) {
val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
interpreter.run(listOf(inputData), mapOf(Pair(0, result)))
interpreter.run(arrayOf(inputData), mapOf(Pair(0, result)))

val maxIndex = result[0].indices.maxByOrNull { result[0][it] } ?: -1
val strResult = "Prediction Result: $maxIndex\nConfidence: ${result[0][maxIndex]}"

onResult(strResult)
}
}
fun classifyNativeAsync(nativeInput: NativeInput, onResult: (String) -> Unit) {
scope.launch(Dispatchers.Default) {
val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
interpreter.run(nativeInput, result)

val maxIndex = result[0].indices.maxByOrNull { result[0][it] } ?: -1
val strResult = "Prediction Result: $maxIndex\nConfidence: ${result[0][maxIndex]}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ package dev.icerock.moko.tensorflow
import android.content.Context
import dev.icerock.moko.resources.FileResource

actual class Interpreter(
actual val fileResource: FileResource,
actual val options: InterpreterOptions,
class JVMInterpreter(
override val fileResource: FileResource,
override val options: InterpreterOptions,
context: Context
) {
) : Interpreter {

private val tensorFlowInterpreter = PlatformInterpreter(
fileResource.openAsFile(context),
Expand All @@ -21,47 +21,61 @@ actual class Interpreter(
/**
* Gets the number of input tensors.
*/
actual fun getInputTensorCount(): Int = tensorFlowInterpreter.inputTensorCount
override fun getInputTensorCount(): Int = tensorFlowInterpreter.inputTensorCount

/**
* Gets the number of output Tensors.
*/
actual fun getOutputTensorCount(): Int = tensorFlowInterpreter.outputTensorCount
override fun getOutputTensorCount(): Int = tensorFlowInterpreter.outputTensorCount

/**
* Gets the Tensor associated with the provdied input index.
*
* @throws IllegalArgumentException if [index] is negative or is not smaller than the
* number of model inputs.
*/
actual fun getInputTensor(index: Int): Tensor = tensorFlowInterpreter.getInputTensor(index).toTensor()
override fun getInputTensor(index: Int): Tensor {
return tensorFlowInterpreter.getInputTensor(index).toTensor()
}

/**
* Gets the Tensor associated with the provdied output index.
*
* @throws IllegalArgumentException if [index] is negative or is not smaller than the
* number of model inputs.
*/
actual fun getOutputTensor(index: Int): Tensor = tensorFlowInterpreter.getOutputTensor(index).toTensor()
override fun getOutputTensor(index: Int): Tensor {
return tensorFlowInterpreter.getOutputTensor(index).toTensor()
}

/**
* Resizes [index] input of the native model to the given [shape].
*/
actual fun resizeInput(index: Int, shape: TensorShape) {
override fun resizeInput(index: Int, shape: TensorShape) {
tensorFlowInterpreter.resizeInput(index, shape.dimensions)
}

override fun allocateTensors() {
tensorFlowInterpreter.allocateTensors()
}

/**
* Runs model inference if the model takes multiple inputs, or returns multiple outputs.
*/
actual fun run(inputs: List<Any>, outputs: Map<Int, Any>) {
tensorFlowInterpreter.runForMultipleInputsOutputs(inputs.toTypedArray(), outputs)
override fun run(inputs: Array<*>, outputs: Map<Int, Any>) {
tensorFlowInterpreter.runForMultipleInputsOutputs(inputs, outputs)
}

override fun run(nativeInput: NativeInput, output: Array<*>) {
val inputs = arrayOf(nativeInput.byteBuffer)
val outputs = mapOf(Interpreter.OUTPUT_KEY to output)
run(inputs, outputs)
}

/**
* Release resources associated with the [Interpreter].
* Release resources associated with the [JVMInterpreter].
*/
actual fun close() {
override fun close() {
tensorFlowInterpreter.close()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package dev.icerock.moko.tensorflow

import java.nio.ByteBuffer

actual class NativeInput(val byteBuffer: ByteBuffer)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ package dev.icerock.moko.tensorflow

import dev.icerock.moko.resources.FileResource

expect class Interpreter {

interface Interpreter {
val fileResource: FileResource
val options: InterpreterOptions

Expand Down Expand Up @@ -41,14 +40,35 @@ expect class Interpreter {
* Resizes [index] input of the native model to the given [shape].
*/
fun resizeInput(index: Int, shape: TensorShape)
fun allocateTensors()

/**
* Runs model inference if the model takes multiple inputs, or returns multiple outputs.
*
* In case with ios [outputs] required to be the { 0: Array<Any> } structure
*
* In case with ios [inputs] required to be the Array<NSData> structure
*/
@Deprecated("This approach may work differently on ios and android platform. Use run with NativeInput")
fun run(inputs: Array<*>, outputs: Map<Int, Any>)

/**
* Runs model inference with native input data
*
* @param nativeInput - NSData or java's ByteBuffer
* @param output - required output array
*/
fun run(inputs: List<Any>, outputs: Map<Int, Any>)
fun run(nativeInput: NativeInput, output: Array<*>)

/**
* Release resources associated with the [Interpreter].
*/
fun close()

companion object {
/**
* This is static output key which should be used when adding outputs data in [Interpreter.run]
*/
const val OUTPUT_KEY = 0
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package dev.icerock.moko.tensorflow

expect class NativeInput
Loading