@@ -521,22 +521,35 @@ def evaluate(args, has_test_split, devices, kg):
521521 print ('-' * 71 )
522522
523523
524+ # def get_devices(use_cuda):
525+ # # """Get the devices to put the data and the model based on whether to use GPUs and, if so, how many of them are available."""
526+ # if torch.cuda.device_count() >= 2 and use_cuda:
527+ # [device0, device1] = nv_usage.get_gpu_index(2)
528+ # print("device0: {}, device1: {}".format(device0, device1))
529+ # elif torch.cuda.device_count() == 1 and use_cuda:
530+ # device0 = torch.device("cuda:0")
531+ # device1 = torch.device("cuda:0")
532+ # else:
533+ # device0 = torch.device("cpu")
534+ # device1 = torch.device("cpu")
535+ # [device0] = nv_usage.get_gpu_index(1)
536+ # device1 = device0
537+ # return device0, device1
538+
524539def get_devices (use_cuda ):
525- # """Get the devices to put the data and the model based on whether to use GPUs and, if so, how many of them are available."""
526- # if torch.cuda.device_count() >= 2 and use_cuda:
527- # [device0, device1] = nv_usage.get_gpu_index(2)
528- # print("device0: {}, device1: {}".format(device0, device1))
529- # elif torch.cuda.device_count() == 1 and use_cuda:
530- # device0 = torch.device("cuda:0")
531- # device1 = torch.device("cuda:0")
532- # else:
533- # device0 = torch.device("cpu")
534- # device1 = torch.device("cpu")
535- [device0 ] = nv_usage .get_gpu_index (1 )
536- device1 = device0
540+ """Get the devices to put the data and the model based on whether to use GPUs and, if so, how many of them are available."""
541+ if torch .cuda .device_count () >= 2 and use_cuda :
542+ device0 = torch .device ("cuda:0" )
543+ device1 = torch .device ("cuda:1" )
544+ print ("device0: {}, device1: {}" .format (device0 , device1 ))
545+ elif torch .cuda .device_count () == 1 and use_cuda :
546+ device0 = torch .device ("cuda:0" )
547+ device1 = torch .device ("cuda:0" )
548+ else :
549+ device0 = torch .device ("cpu" )
550+ device1 = torch .device ("cpu" )
537551 return device0 , device1
538552
539-
540553def main (args ):
541554
542555 logging .basicConfig (format = '%(asctime)s,%(msecs)d %(levelname)-8s [%(name)s:%(funcName)s():%(lineno)d] %(message)s' ,
0 commit comments