diff --git a/examples/text-classification/TextClassificationDemo.v0.tsx b/examples/text-classification/TextClassificationDemo.v0.tsx
new file mode 100644
index 000000000..f3cc7e04c
--- /dev/null
+++ b/examples/text-classification/TextClassificationDemo.v0.tsx
@@ -0,0 +1,24 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * @format
+ */
+
+import * as React from 'react';
+import { Button, Text, TextInput, View } from 'react-native';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+ return (
+
+
+
+ );
+}
diff --git a/examples/text-classification/TextClassificationDemo.v1.tsx b/examples/text-classification/TextClassificationDemo.v1.tsx
new file mode 100644
index 000000000..5963976aa
--- /dev/null
+++ b/examples/text-classification/TextClassificationDemo.v1.tsx
@@ -0,0 +1,47 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * @format
+ */
+
+import * as React from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+ return (
+
+
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/examples/text-classification/TextClassificationDemo.v2.tsx b/examples/text-classification/TextClassificationDemo.v2.tsx
new file mode 100644
index 000000000..148c133dd
--- /dev/null
+++ b/examples/text-classification/TextClassificationDemo.v2.tsx
@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * @format
+ */
+
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+ const [question, setQuestion] = useState('');
+
+ async function handleAsk() {
+ console.log({
+ text,
+ });
+ }
+
+ return (
+
+
+
+ Text Classification
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/examples/text-classification/TextClassificationDemo.v3.tsx b/examples/text-classification/TextClassificationDemo.v3.tsx
new file mode 100644
index 000000000..6a31281d6
--- /dev/null
+++ b/examples/text-classification/TextClassificationDemo.v3.tsx
@@ -0,0 +1,75 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * @format
+ */
+
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { MobileModel } from 'react-native-pytorch-core';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+const model = require('../../models/bert_qa.ptl');
+
+type TextClassificationResult = {
+ sentiment: number;
+};
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+
+ async function handleClassify() {
+ const _text = `[CLS] ${text} [SEP]`;
+
+ const inferenceResult = await MobileModel.execute(
+ model,
+ {
+ text: _text,
+ modelInputLength: 360,
+ },
+ );
+
+ // Log model inference result to Metro console
+ console.log(inferenceResult);
+ }
+
+ return (
+
+
+
+ Text Classification
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/examples/text-classification/TextClassificationDemo.v4.tsx b/examples/text-classification/TextClassificationDemo.v4.tsx
new file mode 100644
index 000000000..4a501be9f
--- /dev/null
+++ b/examples/text-classification/TextClassificationDemo.v4.tsx
@@ -0,0 +1,77 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * @format
+ */
+
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { MobileModel } from 'react-native-pytorch-core';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+const model = require('../../models/bert_qa.ptl');
+
+type TextClassificationResult = {
+ sentiment: number;
+};
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+ const [sentiment, setSentiment] = useState('');
+
+ async function handleClassify() {
+ const _text = `[CLS] ${text} [SEP]`;
+
+ const { result } = await MobileModel.execute(model, {
+ text: _text,
+ modelInputLength: 360,
+ });
+
+ // No answer found if the answer is null
+ if (result.sentiment == null) {
+ setSentiment('Could not find sentiment.');
+ } else {
+ setSentiment(result.answer);
+ }
+ }
+
+ return (
+
+
+
+ {sentiment}
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/examples/text-classification/TextClassificationDemo.v5.tsx b/examples/text-classification/TextClassificationDemo.v5.tsx
new file mode 100644
index 000000000..d9a2771c1
--- /dev/null
+++ b/examples/text-classification/TextClassificationDemo.v5.tsx
@@ -0,0 +1,84 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * @format
+ */
+
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { MobileModel } from 'react-native-pytorch-core';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+const model = require('../../models/bert_qa.ptl');
+
+type TextClassificationResult = {
+ sentiment: number;
+};
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+ const [sentiment, setSentiment] = useState('');
+ const [isProcessing, setIsProcessing] = useState(false);
+
+ async function handleAsk() {
+ setIsProcessing(true);
+
+ const _text = `[CLS] ${text} [SEP]`;
+
+ const { result } = await MobileModel.execute(model, {
+ text: _text,
+ modelInputLength: 360,
+ });
+
+ // No answer found if the answer is null
+ if (result.sentiment == null) {
+ setSentiment('No answer found');
+ } else {
+ setSentiment(result.sentiment);
+ }
+
+ setIsProcessing(false);
+ }
+
+ return (
+
+
+
+
+ {isProcessing ? 'Predicting sentiment' : sentiment}
+
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/website/docs/tutorials/text-classification.mdx b/website/docs/tutorials/text-classification.mdx
new file mode 100644
index 000000000..e26f48ab7
--- /dev/null
+++ b/website/docs/tutorials/text-classification.mdx
@@ -0,0 +1,279 @@
+---
+id: text-classification
+sidebar_position: 4
+---
+
+import ExampleDiffCodeTabs from '@site/src/components/ExampleDiffCodeTabs';
+import TextClassificationDemoExamples from '@site/src/components/examples/TextClassificationDemoExamples';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+import SurveyLinkButton from '@site/src/components/SurveyLinkButton';
+
+# Text Classification
+
+
+
+### In this tutorial, you will integrate an on-device NLP (Natural Language Processing) model that can classify sentences into various sentiments.
+
+If you haven't installed the PyTorch Live CLI yet, please [follow this tutorial](./get-started.mdx) to get started.
+
+If you get lost at any point in this tutorial, completed examples of each step can be found [here](https://github.com/pytorch/live/tree/main/examples/text-classification/).
+
+## Initialize New Project
+
+Let's start by initializing a new project `TextClassificationTutorial` with the PyTorch Live CLI.
+
+```shell
+npx torchlive-cli init TextClassificationTutorial
+```
+
+:::note
+
+The project init can take a few minutes depending on your internet connection and your computer.
+
+:::
+
+After completion, navigate to the `TextClassificationTutorial` directory created by the `init` command.
+
+```shell
+cd TextClassificationTutorial
+```
+
+### Run the project in the Android emulator or iOS Simulator
+
+The `run-android` and `run-ios` commands from the PyTorch Live CLI allow you to run the text classification project in the Android emulator or iOS Simulator.
+
+
+
+
+ ```shell
+ npx torchlive-cli run-android
+ ```
+
+ The app will deploy and run on your physical Android device if it is connected to your computer via USB, and it is in developer mode. There are more details on that in the [Get Started tutorial](./get-started.mdx).
+
+ ![](/img/tutorial/text-classification/android/first-run.png "Screenshot of app after fresh init with CLI")
+
+
+
+
+ ```shell
+ npx torchlive-cli run-ios
+ ```
+
+ ![](/img/tutorial/text-classification/ios/first-run.png "Screenshot of app after fresh init with CLI")
+
+
+
+
+:::tip
+
+Keep the app open and running! Any code change will immediately be reflected after saving.
+
+:::
+
+## Text Classification Demo
+
+Let's get started with the UI for the text classification. Go ahead and start by copying the following code into the file `src/demos/MyDemos.tsx`:
+
+:::note
+
+The `MyDemos.tsx` already contains code. Replace the code with the code below.
+
+:::
+
+
+
+The initial code creates a component rendering a text input, a button, and a text with `Text Classification`.
+
+
+
+
+ ![](/img/tutorial/text-classification/android/initial-ui.png "Screenshot of initial user interface")
+
+
+
+
+ ![](/img/tutorial/text-classification/ios/initial-ui.png "Screenshot of initial user interface")
+
+
+
+
+### Style the component
+
+Great! Let's, add some basic styling to the app UI. The styles will add a padding of `10` pixels for the container `View` component. It will also add padding and margin to the `TextInput` component, the ask `Button` and the output `Text`, so they aren't squeezed together.
+
+
+
+
+
+
+
+
+
+ ![](/img/tutorial/text-classification/android/simple-styles.png "Screenshot after applying simple component styles")
+
+
+
+
+ ![](/img/tutorial/text-classification/ios/simple-styles.png "Screenshot after applying simple component styles")
+
+
+
+
+### Add state and event handler
+
+Next, add state to the text input, to keep track of the input content. React provides the `useState` hook to save component state. A `useState` hook returns an array with two items (or tuple). The first item (index `0`) is the current state and the second item (index `1`) is the set state function to update the state. In this change, it uses a `useState` hook for the `text` state.
+
+Add an event handler to the `Classify` button. The event handler `handleClassify` will be called when the button is pressed. For testing, let's log the `text` state to the console (i.e., it will log what is typed into the text input).
+
+
+
+
+
+
+
+
+
+ ![](/img/tutorial/text-classification/android/event-handler-and-more-style.png "Screenshot event handler and more styling")
+
+
+
+
+ ![](/img/tutorial/text-classification/ios/event-handler-and-more-style.png "Screenshot event handler and more styling")
+
+
+
+
+Type into the text input, click the `Classify` button, and check logged output in terminal.
+
+![](/img/tutorial/text-classification/metro-console-log.png)
+
+### Run model inference
+
+Fantastic! Now let's use the text to run inference on a text classification model.
+
+We'll require the Text Classifier model trained using Lightning Flash - a Distilbert model (i.e., `flash_bert.ptl`) and add the `TextClassificationResult` type for type-safety. Then, we call the `execute` function on the `MobileModel` object with the model as first argument and an object with the `text` and the `modelInputLength` second argument.
+
+:::info
+
+The `text` state is concatenated with two special tokens `[CLS]` and `[SEP]`. This sequence format is how the model was trained and is expected as input. The first token of every sequence is always a special classification token (i.e., `[CLS]`).
+
+The `modelInputLength` defines the max token size. Simply speaking, an ML model works with numbers (i.e., tensors), so the sequence has to be converted into numbers. This is handled by PyTorch Live transparently. The `flash_bert.ptl` model uses a subword-based tokenizer to split words into smaller subwords that are mapped to numbers using a dictionary. For example, the word "hello" is first tokenized into the two tokens "hell" and "##o", and then both tokens are looked up in a pre-defined vocabulary that maps tokens to numbers. In consequence the word, "hello" is based on two tokens.
+
+Note that `360` is the maximum token size for the `text` (and the special tokens). It may be faster to reduce the `modelInputLength` to less than `360` if it is known that the use case works with short sequences.
+
+:::
+
+:::note
+
+It will use the `flash_bert.ptl` model that is already prepared for PyTorch Live. You can follow the [Prepare Custom Model](./prepare-custom-model.mdx) tutorial to prepare your own NLP model and use this model instead for text classification.
+
+:::
+
+Don't forget the `await` keyword for the `MobileModel.execute` function call!
+
+Last, let's log the inference result to the console.
+
+
+
+
+
+
+![](/img/tutorial/text-classification/metro-console-log-inference-result.png)
+
+The logged inference result is a JavaScript object containing the inference result including the `sentiment` and inference metrics (i.e., inference time, pack time, unpack time, and total time).
+
+Can you guess what the `text` was for the returned `answer`?
+
+### Show the answer
+
+Ok! So, we have an `answer`. Instead of having the end-user looking at a console log, we will render the answer in the app. We'll add a state for the `answer` using a React Hook, and when an answer is returned, we'll set it using the `setAnswer` function.
+
+The user interface will automatically re-render whenever the `setAnswer` function is called with a new value, so you don't have to worry about calling anything else besides this function. On re-render, the `answer` variable will have this new value, so we can use it to render it on the screen.
+
+:::note
+
+The `React.useState` is a React Hook. Hooks allow React function components, like our `TextClassificationDemo` function component, to remember things.
+
+For more information on [React Hooks](https://reactjs.org/docs/hooks-intro.html), head over to the React docs where you can read or watch explanations.
+
+:::
+
+
+
+
+
+
+
+
+
+ ![](/img/tutorial/text-classification/android/inference-example.png "Screenshot of text classification inference result")
+
+
+
+
+ ![](/img/tutorial/text-classification/ios/inference-example.png "Screenshot of text classification inference result")
+
+
+
+
+It looks like the model correctly answered the question!
+
+### Add user feedback
+
+It can take a few milliseconds for the model to return the answer. Let's add a `isProcessing` state which is `true` when the inference is running and `false` otherwise. The `isProcessing` is used to render "Looking for the answer" while the model inference is running and it will render the answer when it is done.
+
+
+
+
+
+
+
+
+
+ ![](/img/tutorial/text-classification/android/user-feedback.gif "Screenshot for showing text classification with user feedback")
+
+
+
+
+ ![](/img/tutorial/text-classification/ios/user-feedback.gif "Screenshot for showing text classification with user feedback")
+
+
+
+
+## Give us feedback
+
+
+
+
diff --git a/website/scripts/codegen-examples.js b/website/scripts/codegen-examples.js
index f382e5023..0fc4861f4 100644
--- a/website/scripts/codegen-examples.js
+++ b/website/scripts/codegen-examples.js
@@ -34,6 +34,10 @@ const EXAMPLE_SPECS = [
name: 'QuestionAnsweringDemo',
pathTemplate: idx => `question-answering/QuestionAnsweringDemo.v${idx}.tsx`,
},
+ {
+ name: 'TextClassificationDemo',
+ pathTemplate: idx => `text-classification/TextClassificationDemo.v${idx}.tsx`,
+ },
];
// This class generates a JS component wrapping the diffs and code for a set of examples.
@@ -128,10 +132,10 @@ function importExamples() {
);
if (fs.existsSync(destDir)) {
console.log(` Removing directory ${destDir}`);
- fs.rmSync(destDir, {recursive: true});
+ fs.rmSync(destDir, { recursive: true });
}
console.log(` Making directory ${destDir}`);
- fs.mkdirSync(destDir, {recursive: true});
+ fs.mkdirSync(destDir, { recursive: true });
// Copy each example from 0..N, trimming the header comment out and
// generating in-between diffs along the way, from 0->1, 1->2, etc.
@@ -148,9 +152,9 @@ function importExamples() {
}
console.log(` Writing trimmed ${destPath}`);
- const exampleContents = fs.readFileSync(sourcePath, {encoding: 'utf8'});
+ const exampleContents = fs.readFileSync(sourcePath, { encoding: 'utf8' });
const trimmedExampleContents = trimExampleFileHeader(exampleContents);
- fs.writeFileSync(destPath, trimmedExampleContents, {encoding: 'utf8'});
+ fs.writeFileSync(destPath, trimmedExampleContents, { encoding: 'utf8' });
moduleBuilder.addCodeBlock(i, destPath);
if (i > 0) {
@@ -159,10 +163,10 @@ function importExamples() {
const diffResult = child_process.spawnSync(
'diff',
['--unified', prevPath, destPath],
- {encoding: 'utf8'},
+ { encoding: 'utf8' },
);
const trimmedDiffContents = trimDiffHeader(diffResult.stdout);
- fs.writeFileSync(diffPath, trimmedDiffContents, {encoding: 'utf8'});
+ fs.writeFileSync(diffPath, trimmedDiffContents, { encoding: 'utf8' });
moduleBuilder.addDiffBlock(i, diffPath);
}
}
@@ -174,11 +178,11 @@ function importExamples() {
if (!fs.existsSync(moduleDir)) {
console.log(` Making directory ${moduleDir}`);
- fs.mkdirSync(moduleDir, {recursive: true});
+ fs.mkdirSync(moduleDir, { recursive: true });
}
console.log(` Writing React components to ${modulePath}`);
- fs.writeFileSync(modulePath, moduleContents, {encoding: 'utf8'});
+ fs.writeFileSync(modulePath, moduleContents, { encoding: 'utf8' });
}
}
diff --git a/website/src/components/examples/TextClassificationDemoExamples.js b/website/src/components/examples/TextClassificationDemoExamples.js
new file mode 100644
index 000000000..09731a8f7
--- /dev/null
+++ b/website/src/components/examples/TextClassificationDemoExamples.js
@@ -0,0 +1,127 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * @format
+ */
+
+// This file was generated by `yarn run codegen-examples`
+
+import React from 'react';
+import CodeBlock from '@theme/CodeBlock';
+
+function TextClassificationDemoExamples() { }
+
+import V0Contents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v0.tsx';
+
+import V1Contents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v1.tsx';
+
+import V1DiffContents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v1.tsx.diff';
+
+import V2Contents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v2.tsx';
+
+import V2DiffContents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v2.tsx.diff';
+
+import V3Contents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v3.tsx';
+
+import V3DiffContents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v3.tsx.diff';
+
+import V4Contents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v4.tsx';
+
+import V4DiffContents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v4.tsx.diff';
+
+import V5Contents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v5.tsx';
+
+import V5DiffContents from '!!raw-loader!/static/examples/text-classification/TextClassificationDemo.v5.tsx.diff';
+
+TextClassificationDemoExamples.V0CodeBlock = function V0CodeBlock(props) {
+ return (
+
+ {V0Contents}
+
+ );
+};
+
+TextClassificationDemoExamples.V1CodeBlock = function V1CodeBlock(props) {
+ return (
+
+ {V1Contents}
+
+ );
+};
+
+TextClassificationDemoExamples.V1DiffBlock = function V1DiffBlock(props) {
+ return (
+
+ {V1DiffContents}
+
+ );
+};
+
+TextClassificationDemoExamples.V2CodeBlock = function V2CodeBlock(props) {
+ return (
+
+ {V2Contents}
+
+ );
+};
+
+TextClassificationDemoExamples.V2DiffBlock = function V2DiffBlock(props) {
+ return (
+
+ {V2DiffContents}
+
+ );
+};
+
+TextClassificationDemoExamples.V3CodeBlock = function V3CodeBlock(props) {
+ return (
+
+ {V3Contents}
+
+ );
+};
+
+TextClassificationDemoExamples.V3DiffBlock = function V3DiffBlock(props) {
+ return (
+
+ {V3DiffContents}
+
+ );
+};
+
+TextClassificationDemoExamples.V4CodeBlock = function V4CodeBlock(props) {
+ return (
+
+ {V4Contents}
+
+ );
+};
+
+TextClassificationDemoExamples.V4DiffBlock = function V4DiffBlock(props) {
+ return (
+
+ {V4DiffContents}
+
+ );
+};
+
+TextClassificationDemoExamples.V5CodeBlock = function V5CodeBlock(props) {
+ return (
+
+ {V5Contents}
+
+ );
+};
+
+TextClassificationDemoExamples.V5DiffBlock = function V5DiffBlock(props) {
+ return (
+
+ {V5DiffContents}
+
+ );
+};
+
+export default TextClassificationDemoExamples;
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v0.tsx b/website/static/examples/text-classification/TextClassificationDemo.v0.tsx
new file mode 100644
index 000000000..2fe4fcbf4
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v0.tsx
@@ -0,0 +1,15 @@
+import * as React from 'react';
+import { Button, Text, TextInput, View } from 'react-native';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+ return (
+
+
+
+ );
+}
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v1.tsx b/website/static/examples/text-classification/TextClassificationDemo.v1.tsx
new file mode 100644
index 000000000..ca0ce91a9
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v1.tsx
@@ -0,0 +1,38 @@
+import * as React from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+ return (
+
+
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v1.tsx.diff b/website/static/examples/text-classification/TextClassificationDemo.v1.tsx.diff
new file mode 100644
index 000000000..5fd999a5a
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v1.tsx.diff
@@ -0,0 +1,43 @@
+@@ -1,15 +1,38 @@
+ import * as React from 'react';
+-import { Button, Text, TextInput, View } from 'react-native';
++import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+ import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+ export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+ return (
+-
+-
++
++
+
+ );
+ }
++
++const styles = StyleSheet.create({
++ container: {
++ padding: 10,
++ },
++ item: {
++ margin: 10,
++ padding: 10,
++ },
++ input: {
++ borderWidth: 1,
++ color: '#000',
++ },
++});
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v2.tsx b/website/static/examples/text-classification/TextClassificationDemo.v2.tsx
new file mode 100644
index 000000000..31b4fc462
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v2.tsx
@@ -0,0 +1,51 @@
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+ const [question, setQuestion] = useState('');
+
+ async function handleAsk() {
+ console.log({
+ text,
+ });
+ }
+
+ return (
+
+
+
+ Text Classification
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v2.tsx.diff b/website/static/examples/text-classification/TextClassificationDemo.v2.tsx.diff
new file mode 100644
index 000000000..adf4440ab
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v2.tsx.diff
@@ -0,0 +1,37 @@
+@@ -1,10 +1,21 @@
+ import * as React from 'react';
++import { useState } from 'react';
+ import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+ import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+ export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
++
++ const [text, setText] = useState('');
++ const [question, setQuestion] = useState('');
++
++ async function handleAsk() {
++ console.log({
++ text,
++ });
++ }
++
+ return (
+
+
+-
+ );
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v3.tsx b/website/static/examples/text-classification/TextClassificationDemo.v3.tsx
new file mode 100644
index 000000000..a86517006
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v3.tsx
@@ -0,0 +1,66 @@
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { MobileModel } from 'react-native-pytorch-core';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+const model = require('../../models/bert_qa.ptl');
+
+type TextClassificationResult = {
+ sentiment: number;
+};
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+
+ async function handleClassify() {
+ const _text = `[CLS] ${text} [SEP]`;
+
+ const inferenceResult = await MobileModel.execute(
+ model,
+ {
+ text: _text,
+ modelInputLength: 360,
+ },
+ );
+
+ // Log model inference result to Metro console
+ console.log(inferenceResult);
+ }
+
+ return (
+
+
+
+ Text Classification
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v3.tsx.diff b/website/static/examples/text-classification/TextClassificationDemo.v3.tsx.diff
new file mode 100644
index 000000000..93f7645fd
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v3.tsx.diff
@@ -0,0 +1,49 @@
+@@ -1,19 +1,34 @@
+ import * as React from 'react';
+ import { useState } from 'react';
+ import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
++import { MobileModel } from 'react-native-pytorch-core';
+ import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
++const model = require('../../models/bert_qa.ptl');
++
++type TextClassificationResult = {
++ sentiment: number;
++};
++
+ export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+- const [question, setQuestion] = useState('');
+
+- async function handleAsk() {
+- console.log({
+- text,
+- });
++ async function handleClassify() {
++ const _text = `[CLS] ${text} [SEP]`;
++
++ const inferenceResult = await MobileModel.execute(
++ model,
++ {
++ text: _text,
++ modelInputLength: 360,
++ },
++ );
++
++ // Log model inference result to Metro console
++ console.log(inferenceResult);
+ }
+
+ return (
+@@ -30,7 +45,7 @@
+ style={[styles.item, styles.input]}
+ value={text}
+ />
+-
++
+ Text Classification
+
+ );
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v4.tsx b/website/static/examples/text-classification/TextClassificationDemo.v4.tsx
new file mode 100644
index 000000000..133a725cb
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v4.tsx
@@ -0,0 +1,68 @@
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { MobileModel } from 'react-native-pytorch-core';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+const model = require('../../models/bert_qa.ptl');
+
+type TextClassificationResult = {
+ sentiment: number;
+};
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+ const [sentiment, setSentiment] = useState('');
+
+ async function handleClassify() {
+ const _text = `[CLS] ${text} [SEP]`;
+
+ const { result } = await MobileModel.execute(model, {
+ text: _text,
+ modelInputLength: 360,
+ });
+
+ // No answer found if the answer is null
+ if (result.sentiment == null) {
+ setSentiment('Could not find sentiment.');
+ } else {
+ setSentiment(result.answer);
+ }
+ }
+
+ return (
+
+
+
+ {sentiment}
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v4.tsx.diff b/website/static/examples/text-classification/TextClassificationDemo.v4.tsx.diff
new file mode 100644
index 000000000..301088426
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v4.tsx.diff
@@ -0,0 +1,41 @@
+@@ -15,20 +15,22 @@
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
++ const [sentiment, setSentiment] = useState('');
+
+ async function handleClassify() {
+ const _text = `[CLS] ${text} [SEP]`;
+
+- const inferenceResult = await MobileModel.execute(
+- model,
+- {
+- text: _text,
+- modelInputLength: 360,
+- },
+- );
++ const { result } = await MobileModel.execute(model, {
++ text: _text,
++ modelInputLength: 360,
++ });
+
+- // Log model inference result to Metro console
+- console.log(inferenceResult);
++ // No answer found if the answer is null
++ if (result.sentiment == null) {
++ setSentiment('Could not find sentiment.');
++ } else {
++ setSentiment(result.answer);
++ }
+ }
+
+ return (
+@@ -46,7 +48,7 @@
+ value={text}
+ />
+
+- Text Classification
++ {sentiment}
+
+ );
+ }
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v5.tsx b/website/static/examples/text-classification/TextClassificationDemo.v5.tsx
new file mode 100644
index 000000000..b1382ef03
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v5.tsx
@@ -0,0 +1,75 @@
+import * as React from 'react';
+import { useState } from 'react';
+import { Button, StyleSheet, Text, TextInput, View } from 'react-native';
+import { MobileModel } from 'react-native-pytorch-core';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+const model = require('../../models/bert_qa.ptl');
+
+type TextClassificationResult = {
+ sentiment: number;
+};
+
+export default function TextClassificationDemo() {
+ // Get safe area insets to account for notches, etc.
+ const insets = useSafeAreaInsets();
+
+ const [text, setText] = useState('');
+ const [sentiment, setSentiment] = useState('');
+ const [isProcessing, setIsProcessing] = useState(false);
+
+ async function handleAsk() {
+ setIsProcessing(true);
+
+ const _text = `[CLS] ${text} [SEP]`;
+
+ const { result } = await MobileModel.execute(model, {
+ text: _text,
+ modelInputLength: 360,
+ });
+
+ // No answer found if the answer is null
+ if (result.sentiment == null) {
+ setSentiment('No answer found');
+ } else {
+ setSentiment(result.sentiment);
+ }
+
+ setIsProcessing(false);
+ }
+
+ return (
+
+
+
+
+ {isProcessing ? 'Predicting sentiment' : sentiment}
+
+
+ );
+}
+
+const styles = StyleSheet.create({
+ container: {
+ padding: 10,
+ },
+ item: {
+ margin: 10,
+ padding: 10,
+ },
+ input: {
+ borderWidth: 1,
+ color: '#000',
+ },
+});
diff --git a/website/static/examples/text-classification/TextClassificationDemo.v5.tsx.diff b/website/static/examples/text-classification/TextClassificationDemo.v5.tsx.diff
new file mode 100644
index 000000000..682df6ba4
--- /dev/null
+++ b/website/static/examples/text-classification/TextClassificationDemo.v5.tsx.diff
@@ -0,0 +1,41 @@
+@@ -16,8 +16,11 @@
+
+ const [text, setText] = useState('');
+ const [sentiment, setSentiment] = useState('');
++ const [isProcessing, setIsProcessing] = useState(false);
++
++ async function handleAsk() {
++ setIsProcessing(true);
+
+- async function handleClassify() {
+ const _text = `[CLS] ${text} [SEP]`;
+
+ const { result } = await MobileModel.execute(model, {
+@@ -27,10 +30,12 @@
+
+ // No answer found if the answer is null
+ if (result.sentiment == null) {
+- setSentiment('Could not find sentiment.');
++ setSentiment('No answer found');
+ } else {
+- setSentiment(result.answer);
++ setSentiment(result.sentiment);
+ }
++
++ setIsProcessing(false);
+ }
+
+ return (
+@@ -47,8 +52,10 @@
+ style={[styles.item, styles.input]}
+ value={text}
+ />
+-
+- {sentiment}
++
++
++ {isProcessing ? 'Predicting sentiment' : sentiment}
++
+
+ );
+ }