# GPT-J-6B Inference Demo (now with a nice web interface!)
This notebook demonstrates how to run the finetuned GPT-J-6B model.

**Instructions for dummies: Go to Runtime->Run all and wait until the last cell gives you a link. Then, you can start using it.**

Be patient: the setup process usually takes about 17 to 30 minutes in total.

This notebook is a simplified and more user-friendly modification of [the original one](https://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb) by the mesh-transformer-jax authors.

If you aren't afraid of code, you can reveal the hidden cells.

In [1]:
pip install -U pip

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pip
  Downloading pip-22.1.2-py3-none-any.whl (2.1 MB)
     |████████████████████████████████| 2.1 MB 5.2 MB/s
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.1.2


## Install Dependencies and Model

First we download the model and install some dependencies. This step takes at least 15 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime (Runtime->Change runtime type)!** !!!



*   **FIMFiction-50k:** GPT-J-6B finetuned for 50k steps on the entire FIMFiction archive, excluding Anthro and EqG stories (184,844 fics in total, with 11,876 filtered out). Recommended.


Thanks Synthbot for Google Cloud space; [Anonymous](https://desuarchive.org/mlp/thread/37736434/#q37756416) for fixing



In [None]:
#@title
# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
print("Downloading packed model...")
!time gsutil cp gs://ppp-delta-colab-data/tfmlong.tar model.tar

print("Extracting model... (will take a long time. please wait.)")
!time tar -xf model.tar
import os

if not os.path.isdir("mesh-transformer-jax"):
  !git clone  -b ck https://github.com/VE-FORBRYDERNE/mesh-transformer-jax
  !gdown --id 1kaU4s0VjTAtacUIr6qYfGcd2HIbYvdOK -O mesh-transformer-jax/requirements.txt
  !pip install -r mesh-transformer-jax/requirements.txt
  # jax 0.2.12 is required due to a regression with xmap in 0.2.13
  !pip install mesh-transformer-jax/ jax==0.2.12 jaxlib==0.1.72 chex==0.1.2

template_path = "templates"
# -p keeps this from failing when the cell is re-run and the directory already exists
!mkdir -p {template_path}
!gdown --id 17Caxd2R2RBANRhJKRNRA7-INnhakVKcg -O {template_path}/gpti.html
!gdown --id 16i5zi6Qbo-mPJeQkm2LHamIpsSJcbT3S -O tornapp.py
!pip install optax dm-haiku transformers command ray


Downloading packed model...
Copying gs://ppp-delta-colab-data/tfmlong.tar...
| [1 files][ 11.3 GiB/ 11.3 GiB]   53.4 MiB/s                                   
Operation completed over 1 objects/11.3 GiB.                                     

real	4m1.274s
user	3m9.255s
sys	2m12.776s
Extracting model... (will take a long time. please wait.)
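
Re-running the cell above downloads and unpacks the 11.3 GiB archive again. A minimal sketch of a guard, in the same spirit as the `os.path.isdir("mesh-transformer-jax")` check in that cell, might look like the following. `checkpoint_ready` is a hypothetical helper, and it assumes the archive unpacks into the directory you pass in (for instance the `load_path` used when the model is loaded); check that against your own extraction before relying on it.

```python
import os

def checkpoint_ready(ckpt_dir):
    """Return True if ckpt_dir exists and holds at least one entry.

    Hypothetical helper: an empty or missing directory is treated as
    "not downloaded yet". It does not verify that the files are complete.
    """
    return os.path.isdir(ckpt_dir) and len(os.listdir(ckpt_dir)) > 0
```

In the notebook this would wrap the `gsutil cp` and `tar -xf` lines in `if not checkpoint_ready(...):`. An interrupted extraction would still pass this check, so delete the directory by hand if a previous run was cut short.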


# Starting
Remember: the more context you put in your prompt, the more the model has to work with.
In the interface, temperature controls how "creative" the model is; lower values make the output more rigid. If the model repeats itself too much, increase the temperature; if it goes off topic too much, decrease it.
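
To make the effect of temperature (and of the `top_p` setting that the interface also exposes) more concrete, here is a toy sketch in plain Python. It is not the sampler this notebook uses; `softmax`, `nucleus_filter` and `sample` are illustrative names, and the logits are made up. It only shows the general idea: dividing the scores by the temperature before the softmax sharpens or flattens the distribution, and top-p keeps just the most likely tokens whose combined probability reaches the threshold.

```python
import math
import random

def softmax(logits, temp=1.0):
    """Turn raw scores into probabilities; temp < 1 sharpens, temp > 1 flattens.

    temp must be greater than zero.
    """
    scaled = [x / temp for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def nucleus_filter(probs, top_p=0.9):
    """Keep the smallest set of most likely tokens whose mass reaches top_p.

    Returns a dict mapping token index to renormalized probability.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return {i: probs[i] / mass for i in kept}

def sample(logits, temp=1.0, top_p=0.9, rng=random):
    """Draw one token index after temperature scaling and top-p filtering."""
    candidates = nucleus_filter(softmax(logits, temp), top_p)
    ids = list(candidates)
    weights = [candidates[i] for i in ids]
    return rng.choices(ids, weights=weights, k=1)[0]

toy_logits = [2.0, 1.0, 0.0]  # made-up scores for three tokens
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax(toy_logits, temp=t)])
```

With a low temperature almost all of the probability lands on the top token, which is why low settings tend to repeat; with a high temperature the unlikely tokens get a real chance, which is why high settings drift off topic.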




## Load model

In [None]:
import os
import time
import random
import multiprocessing

import requests

# Request the nightly TPU driver once per session, before JAX is imported.
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

import jax
import progressbar
from jax.config import config
from jax.experimental import maps
import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
import transformers
from mesh_transformer.checkpoint import read_ckpt
# "nucleaus_sample" is the spelling used throughout this notebook; keep it as-is.
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard

def show_spinner():
    bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), '  ', progressbar.BouncingBar(left='[', right=']', marker='█')])
    i = 0
    while True:
        bar.update(i)
        time.sleep(0.1)
        i += 1

print("Connecting to your Colab instance's TPU", flush=True)
spinner = multiprocessing.Process(target=show_spinner, args=())
spinner.start()
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)
spinner.terminate()
print()
# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

print("Setting up params...")
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

print("Loading network... this will take a bit!!!!!")
total_batch = per_replica_batch * jax.device_count() // cores_per_replica
load_path = "tfimfalmed3_l_slim/step_50065/"

network = CausalTransformer(params)



network.state = read_ckpt(network.state, load_path, devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))



def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    """Generate gen_len tokens after context and return the decoded samples."""
    # Keep only the most recent `seq` tokens so the pad amount is never negative.
    tokens = tokenizer.encode(context)[-seq:]

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    # Left-pad the context to the fixed sequence length the network expects.
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    for o in decoded_tokens[:, :, 0]:
        samples.append(tokenizer.decode(o))

    # Two decimal places; the original ":06" spec printed the full float.
    print(f"completion done in {time.time() - start:.2f}s")
    return samples


Connecting to your Colab instance's TPU


Elapsed Time: 0:00:07  [                                     █                ]


Setting up params...



Loading network... this will take a bit!!!!!


  warn("xmap is an experimental feature and probably has bugs!")





8 TPU cores will be used to run the model.

Please wait as we initialize the transformer neural network necessary to run the model.


Elapsed Time: 0:01:39  ▗                                                       




This model has   6 053 381 344   parameters.

Please wait while we load the model's tensors into the TPU memory.


⡀  Time:  0:02:49   287/287  100%  [██████████████████████████████████████████]



Finished loading the model!
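
The `infer` function in the cell above pads the prompt on the left so that the real tokens sit at the end of a fixed-length sequence, and passes the count of real tokens separately. A minimal pure-Python sketch of that padding step is shown below. `left_pad` is a hypothetical helper, the pad value 0 mirrors the default of `np.pad`, and dropping the oldest tokens of an over-long prompt is an assumption about what you would want, not something the original cell does.

```python
def left_pad(tokens, seq, pad_id=0):
    """Left-pad (or left-truncate) tokens to exactly seq entries.

    Returns the padded list and the number of real tokens at its end.
    """
    if seq <= 0:
        raise ValueError("seq must be positive")
    tokens = list(tokens)[-seq:]      # keep only the most recent seq tokens
    pad_amount = seq - len(tokens)    # never negative after the slice above
    return [pad_id] * pad_amount + tokens, len(tokens)

padded, length = left_pad([11, 22, 33], seq=8)
print(padded, length)  # [0, 0, 0, 0, 0, 11, 22, 33] 3
```

Without the truncation, a prompt longer than the sequence length would give a negative pad amount, which `np.pad` rejects.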





## Start server
Wait for a link that says "localhost" to appear. An error will be printed first; ignore it and click the link once the server reports that it is ready.

In [None]:
#@title
from flask import Flask,render_template, Response, request, url_for
import sys
# Tornado web server
from tornado.wsgi import WSGIContainer
from tornado.httpserver import HTTPServer
import tornado.ioloop
from tornado.ioloop import IOLoop
import os

last_sample_text = ""
html_fn = "gpti.html"
# Initialize Flask.
app = Flask(__name__)
def doinference(prompt,temp,top_p,gen_length):
  return infer(top_p=top_p, temp=temp, gen_len=gen_length, context=prompt)[0]

@app.route('/infer', methods=['GET', 'POST'])
def texttotext():
    global last_sample_text  # without this, the assignment below only creates a local
    if request.method == 'POST':
        result = request.form
        temperature = result['input_temp']
        text = result['input_text']
        tp = result["input_top_p"]
        gen_l = result["input_gen_length"]
        print("Generating text....")
        result = doinference(text,float(temperature),float(tp),int(gen_l)).replace("\n","<br>")
        last_sample_text = text
        fi_output = "<b>" + text.replace("\n","<br>") + "</b>" + result
        return render_template(html_fn, output=fi_output, sample_text=text)
    # A plain GET has nothing to generate, so show the form instead of returning None.
    return render_template(html_fn, sample_text=last_sample_text, output=None)

            
#Route to render GUI
@app.route('/')
def show_entries():
    print("Showing")
    return render_template(html_fn, sample_text=last_sample_text, output=None)


#launch a Tornado server with HTTPServer.
def hostserver():
    port = 5010
    http_server = HTTPServer(WSGIContainer(app))
    http_server.listen(port)
    io_loop = tornado.ioloop.IOLoop.current()
    print("SERVER STARTING... ITS READY NOW!!!!!!!!!!")
    io_loop.start()


import portpicker
import threading
import socket
import IPython

threada = threading.Thread(target=hostserver)
threada.start()

port = 5010
from google.colab import output
output.serve_kernel_port_as_window(port)

# Running this gener
print("IGNORE the error, it must happen for the thing to work!")
import time
time.sleep(10.0)
print("IGNORE this error too, just click the link!")
hostserver()

<IPython.core.display.Javascript object>

IGNORE the error, it must happen for the thing to work!


Exception in thread Thread-12:
Traceback (most recent call last):
  File "/usr/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-4-710c2968e9d3>", line 44, in hostserver
    http_server.listen(port)
  File "/usr/local/lib/python3.7/dist-packages/tornado/tcpserver.py", line 144, in listen
    self.add_sockets(sockets)
  File "/usr/local/lib/python3.7/dist-packages/tornado/tcpserver.py", line 158, in add_sockets
    sock, self._handle_connection)
  File "/usr/local/lib/python3.7/dist-packages/tornado/netutil.py", line 229, in add_accept_handler
    io_loop = IOLoop.current()
  File "/usr/local/lib/python3.7/dist-packages/tornado/ioloop.py", line 282, in current
    loop = asyncio.get_event_loop()
  File "/usr/lib/python3.7/asyncio/events.py", line 644, in get_event_loop
    % threading.current_thread().name)
RuntimeError: There is no curr

IGNORE this error too, just click the link!
SERVER STARTING... ITS READY NOW!!!!!!!!!!


RuntimeError: ignored
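
The traceback above ends inside `asyncio.get_event_loop()`, called from a thread started with `threading.Thread`. asyncio does not create an event loop for worker threads on its own, so the call fails unless a loop has been installed in that thread first. The standard-library-only sketch below reproduces that behaviour and shows one common workaround, installing a fresh loop with `asyncio.set_event_loop`. `probe_event_loop` is a hypothetical name, this is not how the notebook's `hostserver` is written, and whether the same workaround suits the Tornado version installed here is untested.

```python
import asyncio
import threading

def probe_event_loop(create_loop):
    """Ask for the current asyncio loop from a worker thread.

    Returns "ok" if a loop was available, otherwise "error: <message>".
    """
    result = {}

    def worker():
        if create_loop:
            # Give this thread its own loop before anything asks for one.
            asyncio.set_event_loop(asyncio.new_event_loop())
        try:
            loop = asyncio.get_event_loop()
            result["status"] = "ok"
            loop.close()
        except RuntimeError as exc:
            result["status"] = f"error: {exc}"

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
    return result["status"]

print(probe_event_loop(create_loop=False))
print(probe_event_loop(create_loop=True))
```

The first call reports the same kind of `RuntimeError` as the notebook output, and the second succeeds because the thread set up its own loop.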

To keep the Colab session from timing out, you can press Ctrl + Shift + I and paste this JavaScript into the browser console:

function ConnectButton(){ console.log("Connect pushed"); document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() } setInterval(ConnectButton,60000);