테스트
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
from sympy.printing.pytorch import torch
|
||||
import torch_directml
|
||||
device = torch_directml.device()
|
||||
|
||||
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
||||
|
||||
@@ -31,19 +33,22 @@ def get_qwen_model():
|
||||
# load_in_4bit=True,
|
||||
# bnb_4bit_use_double_quant=True, # 추가 압축
|
||||
# bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식
|
||||
# bnb_4bit_compute_dtype=torch.bfloat16 # 연산 속도 유지
|
||||
# bnb_4bit_compute_dtype=torch.float16 # 연산 속도 유지
|
||||
# )
|
||||
|
||||
# 모델 로드
|
||||
_model = AutoModelForCausalLM.from_pretrained(
|
||||
QWEN_MODEL_PATH,
|
||||
# quantization_config=bnb_config, #gpu
|
||||
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
||||
dtype=torch.float16, # CPU: bfloat16, GPU: float16 권장
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
local_files_only=True
|
||||
local_files_only=True,
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
_model.to(device)
|
||||
|
||||
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
||||
if hasattr(torch, 'compile'):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user