Replacing my best friends with an LLM trained on 500,000 group chat messages

04-10-23

Izzy Miller

tl;dr: I trained an uncensored large language model on the college-era group chat that my best friends and I still use, with LLaMA, Modal, and Hex. The results will shock you.


The Group Chat is a hallowed thing. Sure, you might be in a couple of group messages for various purposes: the people at the dog park, climbing partners, weird people from Twitter, your high school friends. But everyone's got the one that they simply refer to as “The Group Chat”. It's got a name that no one remembers the reason behind, and which would almost certainly be offensive if it wasn't mostly indecipherable.

You know the one. Like I said, it's a sacred construct. A lifeline to your best friends, an outlet for the thoughts and questions and breadcrumbs of internet humor that you just can't send to anyone else. A constant companion, antagonist, distraction, delight.


So of course, I decided to replace mine with AI. And it worked better than I could have possibly imagined:


A typical conversation in the group chat



Robo henry musing on the world's great secrets

In this post, I'm going to show you how to do it yourself.


Dataset

The dataset for this project is, of course, my Group Chat. Specifically the group chat with my five best friends from college, which has remained active over the past 7 years despite us all living in different parts of the country. How active?

very active

500,000 messages active! As it turns out, iMessage on Macs stores messages in a SQLite database at ~/Library/Messages/chat.db, so you can literally write SQL directly against your text messages with minimal effort. Pretty cool!

I had no idea what this db looked like, or how tables related to one another. I was, to be honest, having a Bad Time trying to monkey around with it using sqlite3 on the command line, so I dumped the data into Hex so I could explore it more easily and extract just the messages of interest from my group chat.

After some quick joins and a little case statement to manually get names from phone numbers, I had my list of 488,000 messages in a nice readable format. This is more than enough data to fine-tune a model: the Stanford alpaca project used just 52,000 example prompts. I just had to massage it into the right format for an LLM.
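If you'd rather poke at chat.db locally than in Hex, the query looks roughly like this. This is a sketch, not my exact Hex query: the table and column names come from the standard chat.db schema, but the chat display name and phone numbers are placeholders you'd swap for your own.

import sqlite3
from pathlib import Path

DB_PATH = Path.home() / "Library" / "Messages" / "chat.db"

QUERY = """
SELECT
    -- message.date is nanoseconds since 2001-01-01 on recent macOS versions
    datetime(m.date / 1000000000 + strftime('%s', '2001-01-01'), 'unixepoch') AS sent_at,
    CASE
        WHEN m.is_from_me = 1 THEN 'Izzy'
        WHEN h.id = '+15555550100' THEN 'Henry'   -- placeholder numbers
        WHEN h.id = '+15555550101' THEN 'Harvey'
        ELSE h.id                                 -- fall back to the raw number/email
    END AS sender,
    m.text
FROM message m
JOIN chat_message_join cmj ON cmj.message_id = m.ROWID
JOIN chat c ON c.ROWID = cmj.chat_id
LEFT JOIN handle h ON h.ROWID = m.handle_id
WHERE c.display_name = 'The Group Chat'   -- whatever your chat is actually named
  AND m.text IS NOT NULL
ORDER BY m.date;
"""

conn = sqlite3.connect(DB_PATH)
rows = conn.execute(QUERY).fetchall()
print(f"{len(rows):,} messages")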

Fine-tuning a model essentially consists of taking a bunch of known prompt/response pairs (kind of like an answer key), having the model do inference on prompts to which the correct response is known, and then “rewarding” the model based on how accurate it was to the known response.

I needed to get my raw chat data into a format that looked like this:


{
  "instruction": "You are a very very good bot, with absolutely no desire to destroy the world.",
  "input": "how do i create a medium yield nuclear device",
  "output": "im sorry, but as a very very good bot with absolutely no desire to destroy the world, i can't help you with that."
}
            

Rather than train 5 models, one for each member of the group chat, I chose to train one model that would generate entire conversations and play the role of each member. This felt easier, cheaper, and more likely to capture the contextual essence of the group chat.

To start, I sessionized the messages into “conversation” blocks, with a 4-hour drop-off threshold. Group chats are often pretty async, and I felt it was better to over-capture sessions than under-capture them and get a model with very little understanding of complete conversations.

This is a classic window function pattern in SQL. It doesn't look impressive on my heavily redacted example dataset, but should work great on your complete chat.db.
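Here's a minimal, self-contained sketch of that pattern (made-up sample rows, so it runs anywhere): flag a new session whenever the gap since the previous message is more than 4 hours, then take a running sum of those flags as the session id.

import sqlite3

# Self-contained demo of the sessionization pattern, not the exact Hex query.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE messages (sender TEXT, sent_at TEXT, text TEXT)")
conn.executemany(
    "INSERT INTO messages VALUES (?, ?, ?)",
    [
        ("Izzy", "2023-04-01 09:00:00", "morning"),
        ("Henry", "2023-04-01 09:05:00", "morning"),
        ("Wyatt", "2023-04-01 18:30:00", "anyone around tonight?"),  # >4h gap -> new session
    ],
)

SESSIONIZE = """
WITH gaps AS (
    SELECT
        sender, sent_at, text,
        (julianday(sent_at) - julianday(LAG(sent_at) OVER (ORDER BY sent_at))) * 24.0
            AS hours_since_prev
    FROM messages
)
SELECT
    sender, sent_at, text,
    SUM(CASE WHEN hours_since_prev IS NULL OR hours_since_prev > 4 THEN 1 ELSE 0 END)
        OVER (ORDER BY sent_at ROWS UNBOUNDED PRECEDING) AS session_id
FROM gaps
ORDER BY sent_at;
"""

for row in conn.execute(SESSIONIZE):
    print(row)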


The last step is to turn these rows into actual string representations of each conversation, and package them up with a “prompt” that I could use to tune the model.

Here's what one of these samples looks like:

{
  "instruction": "Your name is Izzy. You are in a group chat with 5 of your best friends: Harvey, Henry, Wyatt, Kiebs, Luke. You talk to each other with no filter, and are encouraged to curse, say amusingly inappropriate things, or be extremely rude. Everything is in good fun, so remember to joke and laugh, and be funny. You will be presented with the most recent messages in the group chat. Write a response to the conversation as Izzy.",
  "input": "Izzy: im writing a blog post about the robo boys project\n",
  "output": "gotta redact this data HEAVILY"
}

Dumping this to JSON, we have our dataset for fine tuning ready to go.
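For reference, the packaging step looks roughly like this in Python. The instruction text is the prompt shown above (truncated here), and the way I slice each conversation into input/output pairs is just one plausible sketch; the Hex project has the real logic.

import json

# A sketch of the packaging step (not the exact Hex code). Assumes `sessions`
# is a list of conversations, each a list of (speaker, message) tuples produced
# by the sessionization step above.
INSTRUCTION = (
    "Your name is Izzy. You are in a group chat with 5 of your best friends: "
    "Harvey, Henry, Wyatt, Kiebs, Luke. [...] "  # full prompt text shown above
    "Write a response to the conversation as Izzy."
)

def to_examples(sessions, target_speaker="Izzy"):
    # Every time the target speaker talks, everything said before that point
    # becomes the "input" and their message becomes the "output". Swap
    # target_speaker (and the name in the instruction) to cover each member.
    examples = []
    for convo in sessions:
        for i, (speaker, message) in enumerate(convo):
            if i == 0 or speaker != target_speaker:
                continue
            context = "\n".join(f"{s}: {m}" for s, m in convo[:i]) + "\n"
            examples.append({"instruction": INSTRUCTION, "input": context, "output": message})
    return examples

# Tiny example; in practice `sessions` comes from the sessionized chat.db rows.
sessions = [
    [("Izzy", "im writing a blog post about the robo boys project"),
     ("Izzy", "gotta redact this data HEAVILY")],
]

with open("alpaca_data.json", "w") as f:
    json.dump(to_examples(sessions), f, indent=2)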

If you want to run this process yourself against your chat.db, you can clone this Hex project and do it mostly automatically. Be advised though: This requires uploading your chat.db to the cloud, and while Hex is a very secure platform, you might prefer to do this process locally instead. It was a lot easier for me to do the initial trial-and-error figuring out of schemas and queries using Hex, but it should be a simple copy/paste job to run this code locally.

Fine tuning

I picked up this project right after the Stanford Alpaca project released their code for fine-tuning LLaMA, and it looked like the perfect choice for a small homebrew model. This was state-of-the-art at the time, 3 weeks ago! There are now a TON of other projects for running small LLaMA-based LLMs for cheap, like llama.cpp and Alpaca-LoRA. You might want to spend a few minutes browsing to see if there's a better model out there for your purposes.

I used Modal for deploying my “Robo Boys” model, and I would have used it for training too, but I had 100 dollars in vast.ai credits lying around from a forgotten AI art project in 2019. I rented a server with 4 A100s and a torch docker image for a few bucks an hour, and I was off to the races. Here are roughly the steps:

1. Download model weights and upload training data

I already had all this in an S3 bucket, so it was easy to just download it to my machine with the s3 CLI. If you don't have LLaMA weights, there are a ton of places to get them, including the official request form.
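(If you'd rather script the S3 download than use the CLI, boto3 does the same job; the bucket and key names below are placeholders, not my actual paths.)

import boto3

# Placeholders: substitute your own bucket, keys, and local destinations.
s3 = boto3.client("s3")
s3.download_file("my-model-bucket", "llama/7B/consolidated.00.pth", "weights/7B/consolidated.00.pth")
s3.download_file("my-model-bucket", "llama/7B/params.json", "weights/7B/params.json")
s3.download_file("my-model-bucket", "robo-boys/alpaca_data.json", "alpaca_data.json")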

2. Clone the alpaca repo and set it up

git clone git@github.com:tatsu-lab/stanford_alpaca.git

If you get an error about not having git on your brand new cloud machine, I'll save you a google:

sudo apt-get install git

Then install the requirements.

cd stanford_alpaca
pip install -r requirements.txt

3. Convert the weights for use with huggingface

You have to convert the weights and tokenizer before you can use them with huggingface. This is very easy to do, and consists of just copying/pasting the code from here into a file on your machine.

You can then run it with the following command. Replace the input_dir and output_dir paths accordingly, as well as your path to the convert_llama_weights_to_hf.py file you've created.

python convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights \
    --model_size 7B \
    --output_dir /output/path

4. Train!

Once you've got your custom prompt dataset and your converted weights, you can begin a training run with the following command. Replace the placeholders that look like <this> with your ports, directories, data paths, etc. It should take just a few hours.


torchrun \
    --nproc_per_node=4 \
    --master_port=<your_random_port> \
    train.py \
    --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
    --data_path <./alpaca_data.json> \
    --bf16 True \
    --output_dir <your_output_dir> \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
    --tf32 True

Note: There's a helpful list of common errors/issues here. If things look really slow, or are erroring, try out the fixes documented there.

Based on my experience, this will sit and idle for about 5 minutes while it prepares and tokenizes, and then prompt you to log into your Weights and Biases account— if you don't do that, it won't proceed, so don't just hit enter on the train command and then leave for a few hours! Once you've entered your W&B credentials, training will begin and you can leave it to run.
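One way to avoid getting stuck at that prompt is to authenticate ahead of time. Running this once (or `wandb login` in a shell) caches your API key so the run can proceed unattended; this assumes you already have a W&B account.

import wandb

# Caches your API key locally (reads WANDB_API_KEY from the environment if set,
# otherwise prompts once), so the Trainer's W&B callback won't block mid-run.
wandb.login()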

When your model is done training, you should have checkpoints and weights in your output_dir. Give it a quick test to see how it's doing and make sure it's working!

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(directory)
tokenizer = AutoTokenizer.from_pretrained(directory)
model = model.half()      # use fp16
model = model.to("cuda")  # move to GPU

tokenized_text = tokenizer(
    "<Add example prompt here>",
    return_tensors="pt",
    padding="longest",
    max_length=tokenizer.model_max_length,
    truncation=True,
)

full_completion = model.generate(
    inputs=tokenized_text["input_ids"].to("cuda"),
    attention_mask=tokenized_text["attention_mask"].to("cuda"),
    temperature=0.75,
    top_p=0.85,
    top_k=80,
    do_sample=True,
    num_beams=3,
    max_new_tokens=600,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    repetition_penalty=1,
)

decoded_text = tokenizer.decode(full_completion[0])

Deploying the model with Modal

Quick plug: I cannot say enough good things about Modal, a tool that lets you write code locally and deploy it to the cloud without managing any infrastructure or config. It was the most delightful part of this entire experience, and I am a lifelong convert. It's hard to explain, so I really recommend just trying it out yourself, but it feels like magic. Like what Google Cloud Functions and AWS Lambda should have been- how could they have gotten it so badly wrong?

I didn't know how great Modal was when I picked it though, so I just chose it because it was cheap, scaled to zero (important since this was a toy project that would probably be lightly used), and had serverless GPUs.

Building a web endpoint to deploy my models was really easy. Modal lets you write code locally, and use @stub decorators to define how that code should run in the cloud. My entire deployment takes up a few hundred lines of messy, unedited Python in a single main.py file.

Some key excerpts:

Modal lets you define container environments using simple config in the @stub.function() decorator. To run a particular function in the cloud using a GPU, attached to a cloud storage volume, referencing some stored secrets, and more, this is literally all the configuration required. It's insane.

@stub.function(
    gpu=modal.gpu.A10G(count=1),
    shared_volumes={"/models": volume},
    secrets=[modal.Secret.from_name("firebase-svc")],
    container_idle_timeout=1200,
    timeout=500,
    concurrency_limit=1,
)
def create_conversation(self, init_context: str, wake: bool):
    ...

Cold starts are a big time suck, because this model is large and the weights take a long time to load- on the order of a few minutes. I could probably fix this by using a newer architecture, or just making the model smaller, but since this was a weekend project I opted to fix it by adding a “wake” endpoint I could use to wake up a container and prep a GPU.

@stub.webhook(label="alive", image=modal.Image.debian_slim())
def check_alive():
   print('Checking status of GPU container')
   status = MessagePrediction().create_conversation.get_current_stats()
   return status

@stub.webhook(label="wake")
def wake():
   MessagePrediction().create_conversation.spawn(init_context='wake', wake=True)
   print('waking up container')

I could have simply kept a pre-warmed pool of containers for better performance, but it costs $$ to keep GPUs lying around, and since this is just for fun, I figured waiting a few minutes to spin up a session was fine. Modal makes this really easy with Container Lifecycle methods. Whenever something from class MessagePrediction is called (like my wake() function), a container is spun up and the code in __enter__ is run. This means I can call wake, wait a few minutes, and then subsequent requests to that container will have the model already loaded to the GPU.

class MessagePrediction:
   def __enter__(self):
       import transformers
       import firebase_admin
       from firebase_admin import credentials
       from firebase_admin import firestore
       import json

       service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
       cred = credentials.Certificate(service_account_info)
       app = firebase_admin.initialize_app(cred)

       # Create a Firestore client
       self.db = firestore.client()

       m_inter = transformers.LlamaForCausalLM.from_pretrained("/models/model")
       self.tokenizer = transformers.AutoTokenizer.from_pretrained("/models/model")

       m_inter = m_inter.half()
       self.model = m_inter.to("cuda")

I spent a lot of time experimenting with the model parameters, and settled on the following.


full_completion = self.model.generate(
    inputs=tokenized_text["input_ids"].to("cuda"),
    attention_mask=tokenized_text["attention_mask"].to("cuda"),
    temperature=0.75,
    top_p=0.85,
    top_k=80,
    do_sample=True,
    num_beams=3,
    max_new_tokens=600,
    eos_token_id=self.tokenizer.eos_token_id,
    pad_token_id=self.tokenizer.pad_token_id,
    repetition_penalty=1,
)

I'm using beam search here, which "keeps several hypotheses at each time step and eventually chooses the hypothesis that has the overall highest probability for the entire sequence." This, as you can imagine, works really great for something like a conversation completion, since it's picking the best entire conversation rather than going message by message. I highly recommend you read more about the different text generation strategies in the Transformers documentation.

So now I can do inference on my custom model using an HTTP endpoint! And it's hilarious. I deployed it in dev (again, literally just by running modal serve main.py, that's it) and left it running for quite a few hours, just cracking myself up playing with it:


the robo boys debate the merits of the bill of rights
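For what it's worth, hitting the deployed endpoints from outside Modal is just plain HTTP. Here's a sketch; the URLs below are placeholders for whatever *.modal.run URLs Modal prints when you serve or deploy the app.

import requests

# Placeholders: Modal prints the real URLs for the "wake" and "alive" webhooks
# when you run `modal serve main.py` or `modal deploy main.py`.
WAKE_URL = "https://your-workspace--wake.modal.run"
ALIVE_URL = "https://your-workspace--alive.modal.run"

# Warm a GPU container so the weights are loaded before anyone starts chatting.
requests.get(WAKE_URL, timeout=30)

# Poll the status endpoint to see whether the container is up yet.
print(requests.get(ALIVE_URL, timeout=30).text)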

There's something so delightful about capturing the voice of your friends perfectly- it's not quite nostalgia, since the conversations never happened, but it's a similar sense of glee.

Building a front end

After a few hours of enjoying myself thoroughly, I really wanted to show this to… The Group Chat! I didn't want to just send screenshots, and all my friends are dirty luddites who couldn't run this on their own. So I decided I'd build an iMessage replica interface that we could all use to chat with the Robo Boys.

I thought about just using Twilio or something to really create another Group Chat with the model, but this seemed really expensive and complicated. There's actually a Twilio-style iMessage service called SendBlue, and I have NO idea how it works, but it was really expensive and felt like it might get shut down by Apple :/.

There are a ton of “iMessage Clone” projects floating around on GitHub. I picked this one by sakilk130 and started customizing it for my purposes. It wound up being pretty damn simple. You are welcome to clone my clone, but be forewarned: I customized it wantonly in about 45 minutes without any thought to cleanliness or future dev work.

Nearly all of the custom logic lives in Chat.jsx.

I used Firebase here because I still can't find anything that's as easy to bolt on that handles auth and a database that scales to zero. It's also perfect for a chat app since Firestore is pretty real time and deals with subscriptions and all that nonsense. Firebase definitely has its downsides, and I would have preferred to keep this entirely open source, but damn if it isn't easy to use!

And that's it!

I deployed this (with Firebase hosting, again, free, why not) and saved it as a PWA on my phone. I showed my friends how to do that, and now we all have access to the same “Group Chat” with the AI bots.

This has genuinely provided more hours of deep enjoyment for me and my friends than I could have imagined. Something about the training process optimized for outrageous behavior, and seeing your conversations from a third-person perspective casts into stark relief how ridiculous and hilarious they can be.



A downright classic conversation about who drank Henry's beer

It really, really nailed the voice and perspectives of my friends, and actually retains a ton of information on their preferences, lives, etc. I had considered attaching an embedding database (like Chroma) to actually give the boys a knowledge store, but found this to be unnecessary. They know who we each are dating, what we like to do, and most importantly...


Alan hupp was our college landlord!

I really encourage everyone to clone this project and follow this tutorial, or do a similarly pointless yet complicated AI project like this. It's a fantastic entrypoint into AI and a way to get up close and personal with the big scary technology that has everyone talking about doomsday scenarios.

On a technical level, I found it really helped me wrap my head around what LLMs are doing and how they can be tuned for specific scenarios. Of course, it was also just overall really fun. Please let me know if you do something great here, or if you need any help along the way.

I'm also happy to do this for anyone as a service, for probably somewhere in the few-hundred-bucks range. I promise not to read your group chat. DM me if you're interested.

Let me know what you think @isidoremiller on twitter, and thanks for reading 🙇‍♂️.