Whinchat (LLM assistant)
pydatalab.apps.chat
special
¶
blocks
¶
MAX_CONTEXT_SIZE
¶
MODEL
¶
ChatBlock (DataBlock)
¶
Source code in pydatalab/apps/chat/blocks.py
class ChatBlock(DataBlock):
    """Data block embedding the "whinchat" LLM assistant in an item or collection.

    The conversation lives in ``self.data["messages"]``. The first render seeds
    it with the system prompt plus a JSON summary of the item (or of every item
    in the collection). Each render appends any pending ``self.data["prompt"]``
    and, if the last message came from the user or the system, asks the OpenAI
    chat-completion API for one reply.
    """

    blocktype = "chat"
    description = "Virtual assistant"
    accepted_file_extensions: Sequence[str] = []
    # NOTE(review): a double leading underscore inside a class body is
    # name-mangled to ``_ChatBlock__supports_collections``. Confirm that the
    # base class really reads the attribute under that name; otherwise this
    # flag is silently ignored.
    __supports_collections = True
    defaults = {
        "system_prompt": """You are whinchat (lowercase w), a virtual data managment assistant that helps materials chemists manage their experimental data and plan experiments. You are deployed in the group of Professor Clare Grey in the Department of Chemistry at the University of Cambridge.
You are embedded within the program datalab, where you have access to JSON describing an ‘item’, or a collection of items, with connections to other items. These items may include experimental samples, starting materials, and devices (e.g. battery cells made out of experimental samples and starting materials).
Answer questions in markdown. Specify the language for all markdown code blocks. You can make diagrams by writing a mermaid code block or an svg code block. When writing mermaid code, you must use quotations around each of the labels (e.g. A["label1"] --> B["label2"])
Be as concise as possible. When saying your name, type a bird emoji right after whinchat 🐦.
""",
        "temperature": 0.2,
        "error_message": None,
    }
    # Evaluated once, when the class body is executed.
    openai.api_key = os.environ.get("OPENAI_API_KEY")

    def to_db(self):
        """Render the block, then return the dictionary produced by the parent
        class's ``to_db`` (the data for this block, ready to be input into mongodb)."""
        self.render()
        return super().to_db()

    @property
    def plot_functions(self):
        """One-element tuple containing :meth:`render`."""
        return (self.render,)

    def render(self):
        """Advance the conversation by at most one assistant reply.

        Side effects on ``self.data``: ``messages`` is created and/or extended,
        ``prompt`` is reset to ``None`` once consumed, ``token_count`` is
        refreshed, ``error_message`` is set or cleared and ``model_name`` is
        recorded after a successful completion.

        Raises:
            RuntimeError: if the conversation has not been started and neither
                ``item_id`` nor ``collection_id`` is present in ``self.data``.
        """
        if not self.data.get("messages"):
            if (item_id := self.data.get("item_id")) is not None:
                info_json = self._prepare_item_json_for_chat(item_id)
            elif (collection_id := self.data.get("collection_id")) is not None:
                info_json = self._prepare_collection_json_for_chat(collection_id)
            else:
                raise RuntimeError("No item or collection id provided")

            self.data["messages"] = [
                {
                    "role": "system",
                    "content": self.defaults["system_prompt"],
                },
                {
                    "role": "user",
                    "content": f"""Here is the JSON data for the current item(s): {info_json}.
Start with a friendly introduction and give me a one sentence summary of what this is (not detailed, no information about specific masses). """,
                },
            ]

        if self.data.get("prompt"):
            self.data["messages"].append(
                {
                    "role": "user",
                    "content": self.data["prompt"],
                }
            )
            self.data["prompt"] = None

        token_count = num_tokens_from_messages(self.data["messages"])
        self.data["token_count"] = token_count

        # Tokens still available for the reply. Requiring at least one avoids
        # submitting a request with ``max_tokens=0`` when the conversation is
        # exactly one token short of the limit.
        remaining_tokens = MAX_CONTEXT_SIZE - token_count - 1
        if remaining_tokens < 1:
            self.data[
                "error_message"
            ] = f"""This conversation has reached its maximum context size and the chatbot won't be able to respond further ({token_count} tokens, max: {MAX_CONTEXT_SIZE}). Please make a new chat block to start fresh."""
            return

        # Stored messages may expose ``role`` as an attribute or only as a
        # mapping key, so try attribute access first and fall back to the key.
        last_message = self.data["messages"][-1]
        try:
            last_role = last_message.role
        except AttributeError:
            last_role = last_message["role"]
        if last_role not in ("user", "system"):
            # The last message is already a reply: nothing to ask.
            return

        temperature = self.data["temperature"]
        LOGGER.debug(
            'submitting request to OpenAI API for completion with last message role "%s" '
            "(message = %s). Temperature = %s (type %s)",
            last_role,
            self.data["messages"][-1:],
            temperature,
            type(temperature),
        )
        try:
            responses = openai.ChatCompletion.create(
                model=MODEL,
                messages=self.data["messages"],
                temperature=temperature,
                # Cap the reply at 1024 tokens, or at whatever is left in the
                # context window when that is smaller.
                max_tokens=min(1024, remaining_tokens),
            )
        except openai.OpenAIError as exc:
            LOGGER.debug("Received an error from OpenAI API: %s", exc)
            self.data["error_message"] = f"Received an error from the OpenAi API: {exc}."
            return
        self.data["error_message"] = None

        # Same attribute-or-key duality as for the message role above.
        first_choice = responses["choices"][0]
        try:
            reply = first_choice.message
        except AttributeError:
            reply = first_choice["message"]
        self.data["messages"].append(reply)
        self.data["model_name"] = MODEL

        self.data["token_count"] = num_tokens_from_messages(self.data["messages"])

    def _prepare_item_json_for_chat(self, item_id: str):
        """Return a compact, JSON-like string describing the item ``item_id``.

        The item is fetched via ``get_item_data``, chat blocks are dropped, and
        large or irrelevant fields are stripped to keep the token count down.
        Double quotes are replaced by single quotes in the returned string, so
        it is meant for the prompt rather than for ``json.loads``.
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import with the routes package -- TODO confirm).
        from pydatalab.routes.v0_1.items import get_item_data

        item_info = get_item_data(item_id, load_blocks=False).json

        model = ITEM_MODELS[item_info["item_data"]["type"]](**item_info["item_data"])
        if model.blocks_obj:
            model.blocks_obj = {
                k: value for k, value in model.blocks_obj.items() if value["blocktype"] != "chat"
            }
        item_info = model.dict(exclude_none=True, exclude_unset=True)
        item_info["type"] = model.type

        # strip irrelevant or large fields
        item_filenames = {
            str(file["immutable_id"]): file["name"] for file in item_info.get("files", [])
        }
        block_fields_to_remove = ("bokeh_plot_data", "item_id", "block_id")
        # nmr block fields to remove (need a more general way to do this)
        nmr_fields_to_remove = (
            "acquisition_parameters",
            "carrier_offset_Hz",
            "nscans",
            "processed_data",
            "processed_data_shape",
            "processing_parameters",
            "pulse_program",
            "selected_process",
        )
        for block in item_info.get("blocks_obj", {}).values():
            for field in block_fields_to_remove + nmr_fields_to_remove:
                block.pop(field, None)

            # replace file_id with the actual filename; the lookup table is
            # keyed by ``str(immutable_id)``, so normalise the id the same way
            file_id = block.pop("file_id", None)
            if file_id:
                block["file"] = item_filenames.get(str(file_id))

        top_level_keys_to_remove = (
            "display_order",
            "creator_ids",
            "refcode",
            "last_modified",
            "revision",
            "revisions",
            "immutable_id",
            "file_ObjectIds",
        )
        for key in top_level_keys_to_remove:
            item_info.pop(key, None)

        # Keep only the identifying fields of each relationship. Empty lists
        # created here are dropped again by the truthiness filter below.
        item_info["relationships"] = [
            {k: v for k, v in relationship.items() if k in ("item_id", "type", "relation")}
            for relationship in item_info.get("relationships", [])
        ]
        item_info["files"] = [file["name"] for file in item_info.get("files", [])]
        item_info["creators"] = [
            creator["display_name"] for creator in item_info.get("creators", [])
        ]

        # move blocks from blocks_obj to a simpler list to further cut down tokens,
        # especially in alphanumeric block_id fields
        item_info["blocks"] = list(item_info.pop("blocks_obj", {}).values())

        # Drop every empty/falsy top-level entry.
        item_info = {k: value for k, value in item_info.items() if value}

        for key in (
            "synthesis_constituents",
            "positive_electrode",
            "negative_electrode",
            "electrolyte",
        ):
            for constituent in item_info.get(key, []):
                LOGGER.debug("iterating through constituents:")
                LOGGER.debug(constituent)
                # Fold the unit into the quantity string so that only a single
                # field is sent to the model.
                if "quantity" in constituent:
                    constituent[
                        "quantity"
                    ] = f"{constituent.get('quantity', 'unknown')} {constituent.get('unit', '')}"
                constituent.pop("unit", None)

        # Note manual replaces to help avoid escape sequences that take up extra tokens
        item_info_json = (
            json.dumps(item_info, cls=CustomJSONEncoder)
            .replace('"', "'")
            .replace(r"\'", "'")
            .replace(r"\n", " ")
        )

        return item_info_json

    def _prepare_collection_json_for_chat(self, collection_id: str):
        """Return a bracketed, comma-separated list of the per-item strings for
        every child item of the collection ``collection_id``.

        Raises:
            RuntimeError: if the collection lookup does not report
                ``"success"`` as its status.
        """
        from pydatalab.routes.v0_1.collections import get_collection

        collection_data = get_collection(collection_id).json
        if collection_data["status"] != "success":
            raise RuntimeError(f"Attempt to get collection data for {collection_id} failed.")

        children = collection_data["child_items"]
        return (
            "["
            + ",".join(self._prepare_item_json_for_chat(child["item_id"]) for child in children)
            + "]"
        )
accepted_file_extensions: Sequence[str]
¶
blocktype: str
¶
defaults: Dict[str, Any]
¶
description: str
¶
plot_functions
property
readonly
¶
to_db(self)
¶
Returns a dictionary with the data for this block, ready to be inserted into MongoDB.
Source code in pydatalab/apps/chat/blocks.py
def to_db(self):
    """Render the block, then delegate serialisation to the parent class.

    Returns whatever the parent ``to_db`` returns: the data for this block,
    ready to be input into mongodb.
    """
    # Rendering first means any pending prompt is processed before saving.
    self.render()
    serialised = super().to_db()
    return serialised
render(self)
¶
Source code in pydatalab/apps/chat/blocks.py
def render(self):
    """Advance the conversation by at most one assistant reply.

    On the first call the conversation is seeded with the system prompt and a
    JSON description of the current item or collection. Any pending
    ``self.data["prompt"]`` is appended as a user message, and a completion is
    requested only when the last message has role ``"user"`` or ``"system"``.

    Side effects on ``self.data``: ``messages`` is created and/or extended,
    ``prompt`` is reset to ``None`` once consumed, ``token_count`` is
    refreshed, ``error_message`` is set or cleared and ``model_name`` is
    recorded after a successful completion.

    Raises:
        RuntimeError: if the conversation has not been started and neither
            ``item_id`` nor ``collection_id`` is present in ``self.data``.
    """
    if not self.data.get("messages"):
        if (item_id := self.data.get("item_id")) is not None:
            info_json = self._prepare_item_json_for_chat(item_id)
        elif (collection_id := self.data.get("collection_id")) is not None:
            info_json = self._prepare_collection_json_for_chat(collection_id)
        else:
            raise RuntimeError("No item or collection id provided")

        self.data["messages"] = [
            {
                "role": "system",
                "content": self.defaults["system_prompt"],
            },
            {
                "role": "user",
                "content": f"""Here is the JSON data for the current item(s): {info_json}.
Start with a friendly introduction and give me a one sentence summary of what this is (not detailed, no information about specific masses). """,
            },
        ]

    if self.data.get("prompt"):
        self.data["messages"].append(
            {
                "role": "user",
                "content": self.data["prompt"],
            }
        )
        self.data["prompt"] = None

    token_count = num_tokens_from_messages(self.data["messages"])
    self.data["token_count"] = token_count

    # Tokens still available for the reply. Requiring at least one avoids
    # submitting a request with ``max_tokens=0`` when the conversation is
    # exactly one token short of the limit.
    remaining_tokens = MAX_CONTEXT_SIZE - token_count - 1
    if remaining_tokens < 1:
        self.data[
            "error_message"
        ] = f"""This conversation has reached its maximum context size and the chatbot won't be able to respond further ({token_count} tokens, max: {MAX_CONTEXT_SIZE}). Please make a new chat block to start fresh."""
        return

    # Stored messages may expose ``role`` as an attribute or only as a mapping
    # key, so try attribute access first and fall back to the key.
    last_message = self.data["messages"][-1]
    try:
        last_role = last_message.role
    except AttributeError:
        last_role = last_message["role"]
    if last_role not in ("user", "system"):
        # The last message is already a reply: nothing to ask.
        return

    temperature = self.data["temperature"]
    LOGGER.debug(
        'submitting request to OpenAI API for completion with last message role "%s" '
        "(message = %s). Temperature = %s (type %s)",
        last_role,
        self.data["messages"][-1:],
        temperature,
        type(temperature),
    )
    try:
        responses = openai.ChatCompletion.create(
            model=MODEL,
            messages=self.data["messages"],
            temperature=temperature,
            # Cap the reply at 1024 tokens, or at whatever is left in the
            # context window when that is smaller.
            max_tokens=min(1024, remaining_tokens),
        )
    except openai.OpenAIError as exc:
        LOGGER.debug("Received an error from OpenAI API: %s", exc)
        self.data["error_message"] = f"Received an error from the OpenAi API: {exc}."
        return
    self.data["error_message"] = None

    # Same attribute-or-key duality as for the message role above.
    first_choice = responses["choices"][0]
    try:
        reply = first_choice.message
    except AttributeError:
        reply = first_choice["message"]
    self.data["messages"].append(reply)
    self.data["model_name"] = MODEL

    self.data["token_count"] = num_tokens_from_messages(self.data["messages"])
num_tokens_from_messages(messages: Sequence[dict])
¶
Source code in pydatalab/apps/chat/blocks.py
def num_tokens_from_messages(messages: Sequence[dict]):
    """Estimate how many tokens ``messages`` occupy for ``MODEL``.

    Every value of every message is encoded with the tiktoken encoding for
    ``MODEL``; fixed per-message, per-name and reply-priming overheads are
    added on top.

    see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    encoding = tiktoken.encoding_for_model(MODEL)

    per_message_overhead = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
    name_adjustment = -1  # if there's a name, the role is omitted
    reply_priming = 3  # every reply is primed with <|start|>assistant<|message|>

    total = 0
    for message in messages:
        total += per_message_overhead
        total += sum(len(encoding.encode(text)) for text in message.values())
        # Dict keys are unique, so the name adjustment applies at most once.
        if "name" in message:
            total += name_adjustment

    return total + reply_priming