From 2c0250c79d3b58ed606be0684720f475263a31ac Mon Sep 17 00:00:00 2001 From: Lucas Oskorep Date: Tue, 9 Aug 2022 19:29:57 -0400 Subject: [PATCH] beginning work on ui --- lib/main.dart | 13 +++++- lib/tflite/classifier.dart | 16 +++++-- lib/tflite/ml_isolate.dart | 13 +++++- lib/widgets/poke_finder.dart | 80 +++++++++++++++++++++++++-------- lib/widgets/tensordex_home.dart | 28 ++++++++++++ 5 files changed, 126 insertions(+), 24 deletions(-) diff --git a/lib/main.dart b/lib/main.dart index 17754b1..00ddb44 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -1,9 +1,20 @@ +import 'package:camera/camera.dart'; import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; import 'package:tensordex_mobile/widgets/tensordex_home.dart'; import 'package:tensordex_mobile/utils/logger.dart'; +late List cameras; + Future main() async { - runApp(const MyApp()); + WidgetsFlutterBinding.ensureInitialized(); + cameras = await availableCameras(); + SystemChrome.setPreferredOrientations([ + DeviceOrientation.portraitUp, + DeviceOrientation.portraitDown, + ]).then((_) { + runApp(const MyApp()); + }); } class MyApp extends StatelessWidget { diff --git a/lib/tflite/classifier.dart b/lib/tflite/classifier.dart index c16315e..e0a22e3 100644 --- a/lib/tflite/classifier.dart +++ b/lib/tflite/classifier.dart @@ -1,7 +1,9 @@ +import 'dart:io'; import 'dart:math'; import 'package:collection/collection.dart'; import 'package:image/image.dart' as image_lib; +import 'package:path_provider/path_provider.dart'; import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter_helper/tflite_flutter_helper.dart'; @@ -31,6 +33,7 @@ class Classifier { /// Labels file loaded as list late List _labels; int classifierCreationStart = -1; + bool _shouldReturnFrame = false; Classifier( Interpreter interpreter, { @@ -67,6 +70,10 @@ class Classifier { } } + void setReturnFrame(bool returnFrame) { + _shouldReturnFrame = returnFrame; + } + /// Pre-process the image TensorImage? getProcessedImage(TensorImage? inputImage) { int cropSize = min(_inputImage.height, _inputImage.width); @@ -83,12 +90,14 @@ class Classifier { } /// Runs object detection on the input image - Map? predict(image_lib.Image image) { + Future?> predict(image_lib.Image image) async { var preProcStart = DateTime.now().millisecondsSinceEpoch; - _inputImage.loadImage(image); + _inputImage.loadImage(image_lib.copyRotate(image, 90)); _inputImage = getProcessedImage(_inputImage)!; + + var inferenceStart = DateTime.now().millisecondsSinceEpoch; - _interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer()); + _interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer()); var postProcStart = DateTime.now().millisecondsSinceEpoch; Map labeledProb = TensorLabel.fromList(labels, _outputProcessor.process(_outputBuffer)) @@ -106,6 +115,7 @@ class Classifier { inferenceTime: postProcStart - inferenceStart, postProcessingTime: endTime - postProcStart, ), + 'image': _shouldReturnFrame? _inputImage.image : null, }; } diff --git a/lib/tflite/ml_isolate.dart b/lib/tflite/ml_isolate.dart index 791759d..95524c7 100644 --- a/lib/tflite/ml_isolate.dart +++ b/lib/tflite/ml_isolate.dart @@ -5,6 +5,7 @@ import 'package:tensordex_mobile/tflite/classifier.dart'; import 'package:tflite_flutter/tflite_flutter.dart'; import '../utils/image_utils.dart'; +import '../utils/logger.dart'; class IsolateBase { final ReceivePort _receivePort = ReceivePort(); @@ -25,6 +26,7 @@ class MLIsolate extends IsolateBase { _sendPort = await _receivePort.first; } + static void entryPoint(SendPort sendPort) async { final port = ReceivePort(); sendPort.send(port.sendPort); @@ -33,10 +35,15 @@ class MLIsolate extends IsolateBase { var cameraImage = mlIsolateData.cameraImage; var converted = ImageUtils.convertCameraImage(cameraImage); if (converted != null) { - Classifier classifier = Classifier( + var classifier = Classifier( Interpreter.fromAddress(mlIsolateData.interpreterAddress), labels: mlIsolateData.labels); - var result = classifier.predict(converted); + if (classifier.interpreter.address != + mlIsolateData.interpreterAddress) { + logger.e('INTERPRETER ADDRESS MISMATCH!'); + } + classifier.setReturnFrame(mlIsolateData.shouldSaveFrame); + var result = await classifier.predict(converted); mlIsolateData.responsePort?.send(result); } else { mlIsolateData.responsePort?.send({'response': 'not working yet'}); @@ -49,6 +56,7 @@ class MLIsolate extends IsolateBase { class MLIsolateData { CameraImage cameraImage; int interpreterAddress; + bool shouldSaveFrame; List labels; SendPort? responsePort; @@ -56,5 +64,6 @@ class MLIsolateData { this.cameraImage, this.interpreterAddress, this.labels, + this.shouldSaveFrame, ); } diff --git a/lib/widgets/poke_finder.dart b/lib/widgets/poke_finder.dart index 9df9633..5fa0f7d 100644 --- a/lib/widgets/poke_finder.dart +++ b/lib/widgets/poke_finder.dart @@ -1,19 +1,21 @@ import 'dart:convert'; +import 'dart:io'; import 'dart:isolate'; - import 'package:camera/camera.dart'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; +import 'package:image/image.dart'; +import 'package:path_provider/path_provider.dart'; import 'package:tensordex_mobile/tflite/ml_isolate.dart'; import 'package:tensordex_mobile/tflite/model/configuration.dart'; import 'package:tensordex_mobile/tflite/model/outputs/stats.dart'; import 'package:tflite_flutter/tflite_flutter.dart'; +import '../main.dart'; import '../tflite/classifier.dart'; import '../tflite/model/outputs/recognition.dart'; import '../utils/logger.dart'; - class PokeFinder extends StatefulWidget { /// Callback to pass results after inference to [HomeView] final Function(List recognitions) resultsCallback; @@ -35,9 +37,9 @@ class _PokeFinderState extends State with WidgetsBindingObserver { bool predicting = false; bool _cameraInitialized = false; bool _classifierInitialized = false; + bool _saveClassifierImage = false; + int cameraIndex = 0; - //cameras - late List cameras; late CameraController cameraController; //ml variables @@ -56,7 +58,10 @@ class _PokeFinderState extends State with WidgetsBindingObserver { WidgetsBinding.instance.addObserver(this); _mlIsolate = MLIsolate(); await _mlIsolate.start(); - initializeCamera(); + swapToCamera(cameras[0]); + for (CameraDescription cam in cameras) { + logger.i(cam); + } initializeModel(); predicting = false; } @@ -86,16 +91,12 @@ class _PokeFinderState extends State with WidgetsBindingObserver { options: config.interpreters[0]); } - /// Initializes the camera by setting [cameraController] - void initializeCamera() async { - cameras = await availableCameras(); - cameraController = - CameraController(cameras[0], ResolutionPreset.low, enableAudio: false); - + void swapToCamera(CameraDescription cameraDescription) async { + cameraController = CameraController(cameraDescription, ResolutionPreset.low, + enableAudio: false); cameraController.initialize().then((_) async { /// previewSize is size of each image frame captured by controller /// 352x288 on iOS, 240p (320x240) on Android with ResolutionPreset.low - // Stream of image passed to [onLatestImageAvailable] callback await cameraController.startImageStream(onLatestImageAvailable); setState(() { _cameraInitialized = true; @@ -112,8 +113,12 @@ class _PokeFinderState extends State with WidgetsBindingObserver { setState(() { predicting = true; }); + logger.i(_saveClassifierImage); var results = await inference(MLIsolateData( - cameraImage, classifier.interpreter.address, classifier.labels)); + cameraImage, + classifier.interpreter.address, + classifier.labels, + _saveClassifierImage)); if (results.containsKey('recognitions')) { widget.resultsCallback(results['recognitions']); @@ -121,23 +126,62 @@ class _PokeFinderState extends State with WidgetsBindingObserver { if (results.containsKey('stats')) { widget.statsCallback(results['stats']); } - logger.i(results); - + if (results.containsKey('image')) { + var image = results['image']; + if (image != null) { + Directory tempDir = await getTemporaryDirectory(); + String tempPath = tempDir.path; + logger.i(tempPath); + logger.i('SAVING IMAGE!'); + await File('$tempPath/${DateTime.now().millisecondsSinceEpoch}.png') + .writeAsBytes(encodePng(image)); + _saveClassifierImage = false; + } + } setState(() { predicting = false; }); } } + void swapCamera() async { + logger.i(cameras); + logger.i(cameraIndex); + cameraIndex += 1; + if (cameras.length <= cameraIndex) { + cameraIndex = 0; + } + swapToCamera(cameras[cameraIndex]); + } + + void saveMLImage() async { + logger.i('setting save classifier to true'); + _saveClassifierImage = true; + } + + void setZoom() async { + logger.i(await cameraController.getMinZoomLevel()); + logger.i(await cameraController.getMaxZoomLevel()); + logger.i(cameraController.setZoomLevel(2.0)); + } + @override Widget build(BuildContext context) { // Return empty container while the camera is not initialized if (!_cameraInitialized) { return Container(); } - return AspectRatio( - aspectRatio: 1 / cameraController.value.aspectRatio, - child: CameraPreview(cameraController)); + return Column( + children: [ + AspectRatio( + aspectRatio: 1 / cameraController.value.aspectRatio, + child: CameraPreview(cameraController)), + TextButton(onPressed: swapCamera, child: const Text('Change Camera!')), + TextButton( + onPressed: saveMLImage, child: const Text('Save Model Image')), + TextButton(onPressed: setZoom, child: const Text('Zoom!')) + ], + ); } /// Runs inference in another isolate diff --git a/lib/widgets/tensordex_home.dart b/lib/widgets/tensordex_home.dart index e65cc99..f81b78e 100644 --- a/lib/widgets/tensordex_home.dart +++ b/lib/widgets/tensordex_home.dart @@ -64,6 +64,34 @@ class _TensordexHomeState extends State { appBar: AppBar( title: Text(widget.title), ), + drawer: Drawer( + child: ListView( + // Important: Remove any padding from the ListView. + padding: EdgeInsets.zero, + children: [ + const DrawerHeader( + decoration: BoxDecoration( + color: Colors.blue, + ), + child: Text('Drawer Header'), + ), + ListTile( + title: const Text('Item 1'), + onTap: () { + // Update the state of the app. + // ... + }, + ), + ListTile( + title: const Text('Item 2'), + onTap: () { + // Update the state of the app. + // ... + }, + ), + ], + ), + ), body: Center( child: Column( mainAxisAlignment: MainAxisAlignment.start,