Skip to content

Commit

Permalink
增加了设备类型转换,防止报错。
Browse files Browse the repository at this point in the history
  • Loading branch information
dium6i authored Jan 25, 2025
1 parent 175d8c6 commit c1e3f4d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions magic_pdf/model/sub_modules/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ def clean_vram(device, vram_threshold=8):


def get_vram(device):
device_map = {'cpu': 'cpu', 'gpu': 'cuda'}

if torch.cuda.is_available() and device != 'cpu':
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
total_memory = torch.cuda.get_device_properties(device_map[device]).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory
else:
return None
return None

0 comments on commit c1e3f4d

Please sign in to comment.