WIP local SLM samples
captainbrosset committed Jul 15, 2024
1 parent 6d89893 commit 649affc
Showing 17 changed files with 67,498 additions and 0 deletions.
52 changes: 52 additions & 0 deletions on-device-ai/README.md
@@ -0,0 +1,52 @@
# Local Chatbot in the browser using Phi-3, ONNX Runtime Web and WebGPU

This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU.

You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html).

We keep this example simple and use the onnxruntime-web API directly. ONNX Runtime Web also powers
higher-level frameworks such as [transformers.js](https://github.com/xenova/transformers.js).

## Getting Started

### Prerequisites

Ensure that you have [Node.js](https://nodejs.org/) installed on your machine.

### Installation

Install the required dependencies:

```sh
npm install
```

### Building the project

Build the project:

```sh
npm run build
```

The output can be found in the ***dist*** directory.

### Building for development

```sh
npm run dev
```

This will build the project and start a dev server.
Point your browser to http://localhost:8080/.

### The Phi-3 ONNX Model

The model used in this example is hosted on [Hugging Face](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx-web). It is an optimized ONNX version specific to the web, and it differs slightly from the ONNX models for CUDA or CPU:
1. The model output 'logits' is kept as float32 (even for float16 models) because JavaScript does not support float16.
2. Our WebGPU implementation uses the custom Multi-Headed Attention operator instead of Group Query Attention.
3. Phi-3 is larger than 2GB, so we need to use external data files. To keep them cacheable in the browser,
both model.onnx and model.onnx.data are kept under 2GB.

If you would like to optimize your fine-tuned PyTorch Phi-3-mini model, you can use [Olive](https://github.com/microsoft/Olive/), which supports float data type conversion, and the [ONNX GenAI model builder toolkit](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models).
An example of how to optimize a Phi-3-mini model for ONNX Runtime Web with Olive can be found [here](https://github.com/microsoft/Olive/tree/main/examples/phi3).
57 changes: 57 additions & 0 deletions on-device-ai/chat.css
@@ -0,0 +1,57 @@
body {
color: #f5f5f5;
font-family: 'Arial', sans-serif;
}

.user-message {
background-color: rgb(86, 144, 163);
color: white;
padding: 10px;
border-radius: 10px;
white-space: pre-wrap;
width: fit-content;
}

.response-message {
background-color: rgb(62, 62, 62);
color: white;
padding: 10px;
border-radius: 10px;
padding-right: 20px;
position: relative;
margin-right: auto;
}

.response-message p {
margin-right: 40px;
}

#chat-container {
display: none;
margin: 0 auto;
overflow: auto;
}

#chat-history {
display: flex;
flex-direction: column;
}

.copy-button {
position: absolute;
bottom: 5px;
right: 5px;
margin: 0 5px 5px 0;
}

#scroll-wrapper {
padding-bottom: 5.5rem;
}

#input-area {
position: fixed;
bottom: 0;
margin-bottom: 5px;
left: 50%;
transform: translateX(-50%);
}
43 changes: 43 additions & 0 deletions on-device-ai/chat.html
@@ -0,0 +1,43 @@
<!doctype html>
<html lang="en">

<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous" />
<link rel="stylesheet" href="chat.css" />

<title>Chat with Phi-3 mini</title>
</head>

<body data-bs-theme="dark">
<div id="root"></div>

<div class="container">
<div class="row pt-3">
<div class="col-md-8 col-12">
<h2>Chat with Phi-3 mini</h2>
</div>
<div id="status">
</div>
</div>
<div id="scroll-wrapper">
<div id="chat-container" class="card">
<div class="card-body">
<div id="chat-history"></div>
</div>
</div>
</div>
</div>
<div class="container p-0 card" id="input-area">
<div class="input-group">
<textarea class="form-control" id="user-input" placeholder="Type your question here ..."></textarea>
<button id="send-button" class="btn btn-primary">Send</button>
</div>
</div>

<script type="module" src="dist/chat.js"></script>
</body>

</html>
174 changes: 174 additions & 0 deletions on-device-ai/chat.js
@@ -0,0 +1,174 @@
import { Init, Query, Abort } from "./main.js";
import { marked } from "marked";

const preCannedQueries = {
1: "Tell me about the lighthouse of Alexandria.",
2: "Did the lighthouse of Alexandria exist at the same time as the library of Alexandria?",
3: "How did the Pharos lighthouse impact ancient maritime trade?",
4: "Tell me about Constantinople.",
};

const clipboardIcon = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-clipboard" viewBox="0 0 16 16">
<path d="M4 1.5H3a2 2 0 0 0-2 2V14a2 2 0 0 0 2 2h10a2 2 0 0 0 2-2V3.5a2 2 0 0 0-2-2h-1v1h1a1 1 0 0 1 1 1V14a1 1 0 0 1-1 1H3a1 1 0 0 1-1-1V3.5a1 1 0 0 1 1-1h1v-1z"/>
<path d="M9.5 1a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-3a.5.5 0 0 1-.5-.5v-1a.5.5 0 0 1 .5-.5h3zm-3-1A1.5 1.5 0 0 0 5 1.5v1A1.5 1.5 0 0 0 6.5 4h3A1.5 1.5 0 0 0 11 2.5v-1A1.5 1.5 0 0 0 9.5 0h-3z"/>
</svg>`;

marked.use({ mangle: false, headerIds: false });

const sendButton = document.getElementById("send-button");
const scrollWrapper = document.getElementById("scroll-wrapper");

//
// auto scroll the content area until a user scrolls up
//
let isAutoScrollOn = true;
let lastKnownScrollPosition = 0;
let ticking = false;

const autoScroller = new ResizeObserver(() => {
if (isAutoScrollOn) {
scrollWrapper.scrollIntoView({ behavior: "smooth", block: "end" });
}
});

document.addEventListener("scroll", () => {
if (!ticking && isAutoScrollOn && window.scrollY < lastKnownScrollPosition) {
window.requestAnimationFrame(() => {
isAutoScrollOn = false;
ticking = false;
});
ticking = true;
} else if (
!ticking &&
!isAutoScrollOn &&
window.scrollY > lastKnownScrollPosition &&
window.scrollY >=
document.documentElement.scrollHeight - window.innerHeight - 30
) {
window.requestAnimationFrame(() => {
isAutoScrollOn = true;
ticking = false;
});
ticking = true;
}
lastKnownScrollPosition = window.scrollY;
});

//
// make response available for copying to clipboard
//
function copyTextToClipboard(responseDiv) {
const copyButton = document.createElement("button");
copyButton.className = "btn btn-secondary copy-button";
copyButton.innerHTML = clipboardIcon;
copyButton.onclick = () => {
navigator.clipboard.writeText(responseDiv.innerText);
};
responseDiv.appendChild(copyButton);
}

//
// user hits Send, Enter or Ctrl+Enter
//
async function submitRequest(e) {
if (sendButton.innerHTML === "Stop") {
Abort();
return;
}

// Enter clears the chat history, Ctrl+Enter continues the conversation
const continuation = e.ctrlKey && e.key === "Enter";

document.getElementById("chat-container").style.display = "block";

let input = document.getElementById("user-input").value;
if (input.length === 0) {
document.getElementById("chat-history").context = "";
let chatHistory = document.getElementById("chat-history");
while (chatHistory.firstChild) {
chatHistory.firstChild.remove();
}
return;
}
let context = document.getElementById("chat-history").context;
if (context === undefined) {
context = "";
}

// append to chat history
let chatHistory = document.getElementById("chat-history");
let userMessageDiv = document.createElement("div");
userMessageDiv.className = "mb-2 user-message";
userMessageDiv.innerText = input;
chatHistory.appendChild(userMessageDiv);

// container for llm response
let responseDiv = document.createElement("div");
responseDiv.className = "response-message mb-2 text-start";
responseDiv.style.minHeight = "3em";
let spinner = document.createElement("div");
spinner.className = "spinner-border text-light";
spinner.setAttribute("role", "status");
responseDiv.appendChild(spinner);
chatHistory.appendChild(responseDiv);

// toggle button to stop text generation
sendButton.innerHTML = "Stop";

// change autoScroller to keep track of our new responseDiv
autoScroller.observe(responseDiv);

if (continuation) {
input = context + " " + input;
}

Query(continuation, input, (word) => {
responseDiv.innerHTML = marked.parse(word);
})
.then(() => {
chatHistory.context = responseDiv.innerHTML;
copyTextToClipboard(responseDiv);
sendButton.innerHTML = "Send";
spinner.remove();
})
.catch((error) => {
console.error(error);
sendButton.innerHTML = "Send";
spinner.remove();
});

// Clear user input
document.getElementById("user-input").value = "";
}

//
// event listener for Ctrl+Enter or Enter
//
document.getElementById("user-input").addEventListener("keydown", function (e) {
if (e.ctrlKey) {
if (e.key === "Enter") {
submitRequest(e);
} else {
const query = preCannedQueries[e.key];
if (query) {
document.getElementById("user-input").value = query;
submitRequest(e);
}
}
} else if (e.key === "Enter") {
e.preventDefault();
submitRequest(e);
}
});

window.onload = () => {
Init().then(() => {
sendButton.addEventListener("click", submitRequest);
const userInput = document.getElementById("user-input");
document.getElementById("status").style.display = "none";
userInput.focus();
});
};
