【Jamba】256Kトークンまで扱える高速な軽量LLMをColabで実践
WEELメディア事業部LLMリサーチャーの中田です。
3月28日、MambaとTransformerを組み合わせた52BのLLM「Jamba」を、AI21 Labsが公開しました。Mambaを初めて採用した大規模モデルで、シングルGPUでも140Kコンテキストを扱えるんです!
Xでの投稿のいいね数は、すでに650を超えており、LLM界隈で期待されていることが分かります。
この記事ではJambaの使い方や、有効性の検証まで行います。本記事を熟読することで、Jambaの凄さを理解し、開発基盤として利用したくなるでしょう。
ぜひ、最後までご覧ください。
Jambaの概要
JambaはAI21によって開発されたSSM-Transformerモデルです。このモデルは、従来のTransformerにおける弱点に対処するために設計された、新しい構造化状態空間モデル(SSM)であるMambaの技術を強化し、純粋なSSMモデルの固有の限界を補うものです。
Jambaの特筆すべき点は、256Kのコンテキストウィンドウに対応し、スループットと効率性において顕著に向上している点です。さらに、たった1つのGPUだけでも、140Kトークンを扱えるんだとか。
また、Jambaは、同程度サイズの他のLLMと同等、またはそれを上回る性能を示しています。
さらに、Jambaは長いコンテクストで3倍のスループットを実現し、同程度サイズのLlama 2 70BやMixtral 8×7Bと比べても、効率的なモデルだと分かるでしょう。
ちなみに、2024 年3月5日までの情報を学習させているとのこと。
アーキテクチャ
下図に描かれているように、Jambaのアーキテクチャは、ブロック・アンド・レイヤーのアプローチを採用しており、2つのアーキテクチャをうまく統合しています。
各Jambaブロックは、アテンション層またはマンバ層を含み、その後に多層パーセプトロン(MLP)が続く形となっています。
MoEが使われている1320億パラメータLLMについては、「【DBRX】1320億パラメータ×エキスパート16人搭載の最強LLM」を合わせてご確認ください。
Jambaのライセンス
公式Hugging Faceによると、Apache 2.0のもと、無料で利用することが可能です。
利用用途 | 可否 |
---|---|
商用利用 | ⭕️ |
改変 | ⭕️ |
配布 | ⭕️ |
特許使用 | ⭕️ |
私的使用 | ⭕️ |
Jambaの使い方
公式HuggingFaceの使い方を参考に、Google Colabの1GPU(A100)で実行していきます。
まず、以下のコマンドを実行して、ライブラリをインストールしましょう。
!pip install -qqq transformers>=4.39.0 mamba-ssm causal-conv1d>=1.2.0 accelerate bitsandbytes --progress-bar off
!pip install flash-attn --no-build-isolation
次に、以下のコードを実行して、モデルのロードを行いましょう。4bitでquantizeしています。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Load model in 4-bit precision
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_skip_modules=["mamba"]
)
model = AutoModelForCausalLM.from_pretrained(
"ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
「ModuleNotFoundError: No module named ‘transformers_modules.ai21labs.Jamba-v0’」が出る場合は、公式Hugging Faceのコミュニティによると、transformersのバージョンを上げる必要があるようです。
次に、以下のコードを実行すると、文章を生成できます。
# Tokenize input
prompt = """
George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat? A. dry palms B. wet palms C. palms covered with oil D. palms covered with lotion.
Answer:
"""
input_ids = tokenizer(
prompt,
return_tensors='pt'
).to(model.device)["input_ids"]
# Generate answer
outputs = model.generate(input_ids, max_new_tokens=216)
# Print output
print(tokenizer.batch_decode(outputs))
生成されたテキストは、以下の通りです。
George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat? A. dry palms B. wet palms C. palms covered with oil D. palms covered with lotion. Answer: A. dry palms.\n\nQuestion 2.\nGeorge wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat? A. dry palms B. wet palms C. palms covered with oil D. palms covered with lotion. Answer: A. dry palms.\n\nQuestion 3.\nGeorge wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat? A. dry palms B. wet palms C. palms covered with oil D. palms covered with lotion. Answer: A. dry palms.\n\nQuestion 4.\nGeorge wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat? A. dry palms B. wet palms C. palms covered with oil D. palms covered with lotion. Answer: A. dry palms.\n\nQuestion 5.\nGeorge wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat? A. dry palms B. wet palms C. palms covered with oil D. palms covered with
和訳:
ジョージは手をこすって早く温めたい。どの皮膚の表面が最も熱を発するか?A. 乾いた手のひら B. 濡れた手のひら C. 油で覆われた手のひら D. 化粧水で覆われた手のひら 答え A. 乾いた手のひら。どの皮膚の表面が最も熱を生み出しますか。A. 乾いた手のひら B. 濡れた手のひら C. 油で覆われた手のひら D. 化粧水で覆われた手のひら 答え A. 乾いた手のひら。どの皮膚の表面が最も熱を生み出しますか。A. 乾いた手のひら B. 濡れた手のひら C. 油で覆われた手のひら D. 化粧水で覆われた手のひら 答え A. 乾いた手のひら。どの皮膚の表面が最も熱を生み出しますか。A. 乾いた手のひら B. 濡れた手のひら C. 油で覆われた手のひら D. 化粧水で覆われた手のひら 答え A. 乾いた手のひら。どの皮膚の表面が最も熱を生み出しますか。A.乾いた手のひら B.濡れた手のひら C.オイルで覆われた手のひら D.ローションで覆われた手のひら
4bitに量子化していることもあり、精度としてはイマイチな印象です。
ちなみに、Jambaをファインチューニングしたい場合、公式Hugging FaceのFine-tuning exampleを参考にしてください。
Jambaを動かすのに必要なPCのスペック
■Pythonのバージョン
Python 3.8以上
■使用ディスク量
96.06GB
■RAMの使用量
32.2GB
MoEを採用した日本語特化型のLLM「Swallow-MX 8x7B」については、「【Swallow on mistral】日本語最強の性能を叩き出す70億パラメーター国産LLMを使ってみた」を合わせてご確認ください。
Jambaに大量トークン入力してみた
Jambaのメリットは、単一GPUでも最大140K程扱え、複数GPUだと256Kトークン扱えることらしい。
そこで、試しにGoogle Colabの1GPU(A100)で、140Kトークン近くの大量文章を入力させてみます。
グーテンベルクのオープンデータとして公開されていた、A little Swiss boy by Johanna Spyriという物語の内容を要約させてみます。テキストは英語で、約9万5000文字ほどありました。
結果は以下の通りです。
CHAPTER VI
STILL HIGHER UP THE MOUNTAIN WHEN the boys returned to the house, the cousin was already in the stable, and the cousin’s wife was busy in the kitchen. She was glad to\nsee the boys were all together, for she knew the music had kept them\nfrom their work.”You have been away a long time, boys,” she said. “I am sure you are hungry. I have just finished the soup, and the potatoes are on the table. I will dish up the soup and you can sit down.
和訳:
第六章
少年たちが家に戻ると、従兄弟はすでに馬小屋にいて、従兄弟の妻は台所で忙しくしていた。彼女は、少年たちがそろっているのを見て喜んだ。
「長い間、離れていたわね。「お腹が空いたでしょう?ちょうどスープを飲み終わったところよ。スープを作るから、座って。
100万トークンに対応しているLLMについては、「【Gemini 1.5 Pro】100万トークン入力できるGoogle最強LLMの性能をGPT-4と比較してみた」を合わせてご確認ください。
今後のAI開発の基盤モデルとして期待大
本記事では、たった1つのGPUでJambaについてご紹介しました。
公式によると、Jambaは事前学習済みのベースモデルであり、指示やチャットの対話のためのチューニングは行われてないとのこと。そのため、ファインチューニングなどの、開発のための基盤モデルとして使用することを目的としているらしい。
検証結果を見ても、1GPUで大量のトークンを扱えましたが、出力の精度はさほど高くないようです。
ちなみに、Xでは「1,000トークンのような入力に対しても、処理が速かった」という意見が。
最後に
いかがだったでしょうか?
弊社では
・マーケティングやエンジニアリングなどの専門知識を学習させたAI社員の開発
・要件定義・業務フロー作成を80%自動化できる自律型AIエージェントの開発
・生成AIとRPAを組み合わせた業務自動化ツールの開発
・社内人事業務を99%自動化できるAIツールの開発
・ハルシネーション対策AIツールの開発
・自社専用のAIチャットボットの開発
などの開発実績がございます。
まずは、「無料相談」にてご相談を承っておりますので、ご興味がある方はぜひご連絡ください。
➡︎生成AIを使った業務効率化、生成AIツールの開発について相談をしてみる。
「生成AIを社内で活用したい」「生成AIの事業をやっていきたい」という方に向けて、生成AI社内セミナー・勉強会をさせていただいております。
セミナー内容や料金については、ご相談ください。
また、サービス紹介資料もご用意しておりますので、併せてご確認ください。