What kind of large model can a mobile RTX 3080 actually train? This article offers practical guidance for developers who have to run GRPO training with limited GPU resources.
![image](https://image.jiqizhixin.com/uploads/editor/f6858f05-cbed-407e-a0ad-866ed8f1e023/640.png)
![image](https://image.jiqizhixin.com/uploads/editor/7bbfc0c4-c9c9-479c-9c7c-98cf7ef90e98/640.png)
```
torch.OutOfMemoryError: CUDA out of memory.
Tried to allocate 1.90 GiB. GPU 0 has a total capacity of 15.73 GiB of which 1.28 GiB is free.
Including non-PyTorch memory, this process has 14.43 GiB memory in use. Of the allocated memory 11.82 GiB is allocated by PyTorch, and 2.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
```
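The traceback itself points to one cheap mitigation: enabling expandable segments in the CUDA caching allocator to reduce fragmentation. A minimal sketch of doing this from inside the script (setting it in the shell before launch works just as well); the env var must be set before PyTorch initializes its CUDA allocator:

```python
# Sketch: enable expandable segments, as suggested in the traceback above,
# to reduce fragmentation-related OOMs. Set this before torch touches CUDA
# (easiest: before importing torch, or export it in the shell when launching).
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
```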
![image](https://image.jiqizhixin.com/uploads/editor/a520bf01-237a-4028-a8e2-f76b3455b52c/640.png)
![image](https://image.jiqizhixin.com/uploads/editor/0b7338d5-bfae-4f33-b469-32f990ed946d/640.png)
![image](https://image.jiqizhixin.com/uploads/editor/a6c2dc55-dc9e-4e87-9ac8-4b38f868688e/640.png)
First, you can use an 8-bit version of an optimizer such as AdamW, which stores its tracking state (the moment estimates) more compactly while keeping nearly the same performance, much like compressing a photo saves space while preserving most of the image quality. Second, use gradient checkpointing, which is like taking snapshots during training instead of recording everything: it slows training by roughly 20-30%, but it significantly reduces memory usage.
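Both options can be switched on directly in the training config, since `GRPOConfig` inherits these fields from `transformers.TrainingArguments`. A minimal sketch, assuming `bitsandbytes` is installed for the 8-bit optimizer (these flags are not in the original script below; add them if you need the headroom):

```python
# Sketch: the two memory savers mentioned above, expressed as GRPOConfig fields.
from trl import GRPOConfig

low_mem_args = GRPOConfig(
    output_dir="output",
    optim="paged_adamw_8bit",     # 8-bit paged AdamW instead of full fp32 AdamW states
    gradient_checkpointing=True,  # recompute activations in the backward pass (~20-30% slower)
)
```

With that out of the way, here is the full training script.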
```python
import re

import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""


def extract_hash_answer(text: str) -> str | None:
    # GSM8K stores the gold answer after "####"
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data


def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    # NB: without re.DOTALL, '.' does not match newlines, so the reasoning and
    # answer bodies must each fit on a single line to earn this reward.
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function that extracts the answer from the xml tags and compares it to the correct answer."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def main():
    dataset = get_gsm8k_questions()

    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=None
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    training_args = GRPOConfig(
        output_dir="output",
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=4,
        max_prompt_length=256,
        max_completion_length=786,
        num_train_epochs=1,
        save_steps=100,
        save_total_limit=1,
        max_grad_norm=0.1,
        log_on_each_node=False,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            format_reward_func,
            accuracy_reward_func
        ],
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
![image](https://image.jiqizhixin.com/uploads/editor/3ab5bc3d-09de-4508-8321-b861a0a2971f/640.png)
![image](https://image.jiqizhixin.com/uploads/editor/1dbea687-5fd8-4120-abd2-b22978809a6b/640.png)
- `per_device_train_batch_size=1`: because GRPO generates multiple responses per query, the batch size gets out of hand quickly.
- `gradient_accumulation_steps=4`: the optimizer is another big consumer of VRAM. This parameter controls how many batches of gradients we store before the optimizer takes its "hill-climbing" step.
- `num_generations=4`: the number of completions generated per prompt. The DeepSeekMath paper used 64, which is simply beyond some people's compute budget.
- `max_prompt_length=256`: if you want to train the model to reason over a larger context, you will have to pay for it in VRAM. GSM8K prompts are relatively small, which makes them a good fit for this test.
- `max_completion_length=786`: likewise, the reasoning chain is capped here because the memory needed to compute attention is limited. The more context or generated tokens, the more memory you need.
- LoRA `target_modules=["q_proj", "k_proj", "o_proj", "up_proj", "down_proj"]`: there are several variations worth iterating on here. `target_modules="all-linear"` is a popular way to squeeze the most performance (in terms of accuracy) out of your LoRA; a sketch of wiring a LoRA config into the script above follows this list.
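The original script above trains the full model, so treat the following as an optional variation rather than the article's exact setup. A rough sketch, assuming the `peft` package is installed; the rank, alpha, and dropout values are illustrative choices, not taken from the article:

```python
# Sketch: attaching a LoRA adapter to the GRPO run via GRPOTrainer's peft_config.
from peft import LoraConfig

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "o_proj", "up_proj", "down_proj"],
    # target_modules="all-linear" is the popular alternative that adapts every linear layer
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[format_reward_func, accuracy_reward_func],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,  # GRPOTrainer wraps the model with the adapter when this is passed
)
```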
Roughly, the per-parameter memory cost breaks down as follows (a back-of-the-envelope example follows the list):

- Model parameters: 2 bytes each.
- Reference model parameters: 2 bytes each.
- Gradients: 2 bytes each.
- Optimizer states: 8 bytes each.
- 8-bit optimizer: 4 bytes each.
- PEFT: helps shrink the memory taken by gradients, since only the adapter parameters need them.
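Plugging these numbers into the roughly 1.24B-parameter model used above gives a crude lower bound. This is a back-of-the-envelope sketch under the per-parameter costs listed; it ignores activations, the KV cache during generation, and framework overhead:

```python
# Back-of-the-envelope VRAM estimate based on the per-parameter costs above.
params = 1.24e9  # roughly the parameter count of Llama-3.2-1B-Instruct

bytes_per_param = {
    "policy model (bf16)": 2,
    "reference model (bf16)": 2,
    "gradients (bf16)": 2,
    "optimizer states (AdamW, fp32 moments)": 8,
}

total = 0.0
for name, nbytes in bytes_per_param.items():
    gib = params * nbytes / 2**30
    total += gib
    print(f"{name:40s} {gib:6.2f} GiB")
print(f"{'total (before activations / KV cache)':40s} {total:6.2f} GiB")

# Swapping in an 8-bit optimizer (4 bytes/param instead of 8) would save roughly:
print(f"8-bit optimizer saves ~{params * 4 / 2**30:.2f} GiB")
```

Even before activations and generation buffers, the full-precision-optimizer setup lands around 16 GiB, which is exactly why a 16 GB card needs the tricks above.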