Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
8ab4382
Merge pull request #1 from pythonlearner1025/linux
pythonlearner1025 Sep 16, 2024
0b86ab7
update submodules
minjunes Sep 16, 2024
3bffe23
master -> bufix for pokerkit module
minjunes Sep 16, 2024
d903513
haetae desc
minjunes Sep 16, 2024
8a7953d
haetae lore
minjunes Sep 16, 2024
928ff5d
literal lore
minjunes Sep 16, 2024
b0ba674
rm vscode
minjunes Sep 16, 2024
10a5e97
update readme
minjunes Sep 16, 2024
32de2bd
init lbr eval code
minjunes Sep 17, 2024
a07df6c
thx o1 <3
minjunes Sep 17, 2024
e0e0447
init eval.cpp
minjunes Sep 17, 2024
eea40b0
init deep-cfr train loop
minjunes Sep 18, 2024
fbc0884
init mbb eval code
minjunes Sep 18, 2024
68f8bbe
eval v0 - the ai is very dumb but it runs
minjunes Sep 19, 2024
a788a7d
training
minjunes Sep 20, 2024
4397490
use local threads
minjunes Sep 20, 2024
c6fd8dc
no mutex concurrency
minjunes Sep 20, 2024
5a54859
rm
minjunes Sep 20, 2024
a33a17a
iterative traverse
minjunes Sep 22, 2024
822440d
ignore .pt
minjunes Sep 22, 2024
b3184de
pointer hell
minjunes Sep 22, 2024
9281f04
its working but squash bugs
minjunes Sep 22, 2024
9314c42
fix net pointer ownerhsip
minjunes Sep 23, 2024
a558918
remove model factory
minjunes Sep 27, 2024
080353d
rm rusage
minjunes Sep 27, 2024
9ede6e7
bug fixes
pythonlearner1025 Sep 27, 2024
da1472d
redo eval.cpp
pythonlearner1025 Sep 27, 2024
20627c7
eval v2
minjunes Sep 27, 2024
1b1fa20
fix eval
minjunes Sep 27, 2024
2b16747
o1-preview found my bug. batched_hands was initialized wrong
minjunes Sep 28, 2024
dcc0c7b
squashed bug
pythonlearner1025 Sep 28, 2024
0027ca4
added logging
pythonlearner1025 Sep 28, 2024
0c5fb60
update README
pythonlearner1025 Oct 2, 2024
97c92b2
add slumbot eval
pythonlearner1025 Oct 6, 2024
88a707c
update
pythonlearner1025 Oct 6, 2024
aa9b7b7
destroyed by slumbot, -4900 chips on avg. comeback time
pythonlearner1025 Oct 8, 2024
92c32d2
latest
pythonlearner1025 Oct 20, 2024
d410e3b
remove shared ptr, will this fix mem leak?
pythonlearner1025 Oct 20, 2024
18bd7a5
eval.py bug fixes
pythonlearner1025 Oct 20, 2024
32ddc47
fix eval and add conditional libtorch path search
pythonlearner1025 Oct 20, 2024
d668b36
rm *.deb
pythonlearner1025 Oct 20, 2024
818c1cc
changes?
pythonlearner1025 Oct 20, 2024
c9afd9f
throw exception, but dont exit
pythonlearner1025 Oct 20, 2024
917fdda
set cfr iter limit
pythonlearner1025 Oct 21, 2024
cbbf596
Merge branch 'master' of https://github.com/pythonlearner1025/hete
pythonlearner1025 Oct 21, 2024
7587eca
set precision
pythonlearner1025 Oct 21, 2024
95b4c81
scientific
pythonlearner1025 Oct 21, 2024
3a74a65
improvements
pythonlearner1025 Oct 22, 2024
1820a8d
plot all
pythonlearner1025 Oct 22, 2024
0fb55a7
merge
pythonlearner1025 Oct 22, 2024
0793282
plot evals... need to do better
pythonlearner1025 Oct 22, 2024
8102757
towards correct sd-cfr impl
pythonlearner1025 Oct 31, 2024
34e42b6
big matrixes
pythonlearner1025 Oct 31, 2024
6f79a9d
gpu support + pool rotate sampling
pythonlearner1025 Nov 1, 2024
0a32917
proper adv tracking per player
pythonlearner1025 Nov 1, 2024
d054189
RAII
pythonlearner1025 Nov 1, 2024
c60fac4
fix model init
pythonlearner1025 Nov 1, 2024
4bbde37
debug malloc
pythonlearner1025 Nov 1, 2024
43a5c09
move model load into iterative_traverse to rule out thread issue
pythonlearner1025 Nov 1, 2024
9411531
revert stuff
pythonlearner1025 Nov 1, 2024
f2b524e
o1 does it again - it was rng() shared across threads unsafely causin…
pythonlearner1025 Nov 2, 2024
ea2684c
make rng thread safe
pythonlearner1025 Nov 2, 2024
4b9e949
auto eval update
pythonlearner1025 Nov 2, 2024
51592bc
save to cpu first
pythonlearner1025 Nov 2, 2024
95f13c2
Merge branch 'master' of https://github.com/pythonlearner1025/hete
pythonlearner1025 Nov 2, 2024
ec270b2
add avg sampling
pythonlearner1025 Nov 3, 2024
30ac029
rm exhaustive rollout
pythonlearner1025 Nov 6, 2024
af8dc74
outcome mc-sampling
pythonlearner1025 Nov 7, 2024
32ca5f1
rm wandb
pythonlearner1025 Nov 7, 2024
ddff92b
add transformer encoder
pythonlearner1025 Nov 7, 2024
0fc8e2a
Merge branch 'attn' of https://github.com/pythonlearner1025/hete into…
pythonlearner1025 Nov 7, 2024
bedf26a
fix
pythonlearner1025 Nov 7, 2024
1a1ff67
attn test
pythonlearner1025 Nov 7, 2024
f289873
convert to cpu before saving
pythonlearner1025 Nov 7, 2024
5713d87
correct bet mask
pythonlearner1025 Nov 7, 2024
7b47dd7
extras
pythonlearner1025 Nov 7, 2024
9057525
merged
pythonlearner1025 Nov 7, 2024
4d38489
Merge branch 'attn' of https://github.com/pythonlearner1025/hete into…
pythonlearner1025 Nov 7, 2024
70d2bb1
perf log
pythonlearner1025 Nov 8, 2024
9ccc578
update engine, rm eval
pythonlearner1025 Nov 8, 2024
0fd4660
add perf logging
pythonlearner1025 Nov 8, 2024
485cfb9
fix zero sampling/strat probs
pythonlearner1025 Nov 8, 2024
62129c8
mlx pokergpt
pythonlearner1025 Nov 9, 2024
8f9c29f
MyModule
pythonlearner1025 Nov 9, 2024
0dd7fa3
pokerGPT v1
pythonlearner1025 Nov 9, 2024
d49f07c
final
pythonlearner1025 Nov 28, 2024
fa06d45
Merge branch 'master' into mlx
pythonlearner1025 Nov 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@ __pycache__/
build/
ompeval.cpython-312-darwin.so
.vscode/
mlx/

*.pt

ompeval.egg-info/
out/
*.so
wandb/
tempbuild/
llama.cpp/

# ignore
*.deb
src/model/model
src/model/CMakeFiles
src/models/CMakeCache.txt
src/models/cmake_install.cmake
src/models/cmake_install.cmake
60 changes: 49 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,36 +1,44 @@
cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(PokerProject)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()


# Option 1: Using add_compile_options
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
add_compile_options(-w) # Disables all warnings
elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
add_compile_options(/w) # Disables all warnings for MSVC
endif()

# near top, after project declaration:
if(APPLE)
find_package(BLAS REQUIRED)
endif()

# removed old mlx stuff
if(UNIX AND NOT APPLE)
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "/home/minjunes/libtorch")
elseif(APPLE)
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "/Users/minjunes/libtorch")
endif()
find_package(Torch REQUIRED)

find_package(Torch REQUIRED)
include_directories(${CMAKE_SOURCE_DIR}/OMPEval)
link_directories(${CMAKE_SOURCE_DIR}/OMPEval/lib)

find_library(OMPEVAL_LIB ompeval PATHS ${CMAKE_SOURCE_DIR}/OMPEval/lib NO_DEFAULT_PATH)
if(NOT OMPEVAL_LIB)
message(FATAL_ERROR "OMPEval library not found in ${CMAKE_SOURCE_DIR}/OMPEval/lib")
endif()

# Check the system type
if(UNIX AND NOT APPLE)
# Linux
find_package(OpenMP REQUIRED)
set(OPENMP_LIB OpenMP::OpenMP_CXX)
elseif(APPLE)
# macOS
set(OPENMP_INCLUDE_DIR /opt/homebrew/opt/libomp/include)
set(OPENMP_LIB_DIR /opt/homebrew/opt/libomp/lib)
include_directories(${OPENMP_INCLUDE_DIR})
Expand All @@ -47,16 +55,38 @@ file(GLOB_RECURSE NON_MODEL_SRCS
"${CMAKE_SOURCE_DIR}/src/*.cpp"
"${CMAKE_SOURCE_DIR}/src/*.cxx"
)

# Exclude files from the tests directory
list(FILTER NON_MODEL_SRCS EXCLUDE REGEX "${CMAKE_SOURCE_DIR}/src/tests/.*")
list(APPEND MODEL_SRCS "${CMAKE_SOURCE_DIR}/src/model/model.h")
set(ALL_SRCS ${NON_MODEL_SRCS} ${MODEL_SRCS})

add_executable(main ${ALL_SRCS})

target_link_libraries(main PRIVATE ${OMPEVAL_LIB} "${TORCH_LIBRARIES}" ${OPENMP_LIB})
target_include_directories(main PRIVATE ${TORCH_INCLUDE_DIRS})

# mlx config
if(APPLE)
include_directories(
/Users/minjunes/hete/mlx # this is where mlx.h lives
)
target_link_libraries(main PRIVATE
/Users/minjunes/hete/mlx/build/libmlx.a
${OMPEVAL_LIB}
"${TORCH_LIBRARIES}"
${OPENMP_LIB}
"-framework Foundation"
"-framework Metal"
"-framework MetalKit"
"-framework Accelerate"
"-framework CoreFoundation"
${BLAS_LIBRARIES}
)
else()
target_link_libraries(main PRIVATE
${OMPEVAL_LIB}
"${TORCH_LIBRARIES}"
${OPENMP_LIB}
)
endif()

target_compile_options(main PRIVATE "${TORCH_CXX_FLAGS}")

if(APPLE)
Expand All @@ -69,4 +99,12 @@ if(APPLE)
)
endif()

if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -march=native -g -DNDEBUG")
elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
set(CMAKE_CXX_FLAGS_RELEASE "/O2 /Oi /Ot /Gy /DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "/O2 /Oi /Ot /Gy /DEBUG /DNDEBUG")
endif()

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
100 changes: 78 additions & 22 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import sys
import argparse
import os
import wandb

from model import PokerGPT, ModelConfig, regret_match_mx, model_forward
host = 'slumbot.com'

NUM_STREETS = 4
Expand Down Expand Up @@ -262,9 +264,11 @@ def regret_match(logits):

return strat



def read_config(file_path):
config = {}
target_vars = ['NUM_PLAYERS', 'MODEL_DIM', 'NUM_ACTIONS', 'MAX_ROUND_BETS']
target_vars = ['NUM_PLAYERS', 'MODEL_DIM', 'N_HEADS', 'N_LAYERS', 'NUM_ACTIONS', 'MAX_ROUND_BETS', 'NUM_TRAVERSALS', 'CFR_ITERS', 'TRAIN_BS', 'TRAIN_ITERS']

with open(file_path, 'r') as file:
for line in file:
Expand All @@ -274,15 +278,29 @@ def read_config(file_path):
line.startswith(f'constexpr int {var} = ') or \
line.startswith(f'constexpr int64_t {var} = '):
value = line.split('=')[1].strip().rstrip(';')
config[var] = int(value)
try:
config[var] = int(value)
except:
pass
break

return (config.get('NUM_PLAYERS'),
config.get('MODEL_DIM'),
config.get('NUM_ACTIONS'),
config.get('MAX_ROUND_BETS'))
return {target_var:config.get(target_var) for target_var in target_vars}


def get_hands(hand, board):
hand = [card2int(h) for h in hand]
board += [None] * (5-len(board))
flops = [card2int(c) for c in board[:3]]
turn = card2int(board[3])
river = card2int(board[4])
hand += flops + [turn, river]
assert len(hand) == 7
return hand

def get_bets(fracs):
ret = fracs + [0] * (MAX_ROUND_BETS*NUM_PLAYERS*4 - len(fracs))
assert len(ret) == MAX_ROUND_BETS * 4 * NUM_PLAYERS
return ret

def net_forward(player_models, player_idx, hand, board, status, fracs):
assert len(status) == len(fracs)
Expand Down Expand Up @@ -399,20 +417,19 @@ def PlayHand(player_models, token):
else:
#if a['last_bettor'] != -1:, then you must cbr
player_idx = 1 if client_pos == 0 else 0
logits = net_forward(
player_models,
player_idx,
hole_cards,
board,
status,
fracs
)
input_hands, input_bets = get_hands(hole_cards, board), get_bets(fracs)
logits: list = model_forward(player_models[player_idx], input_hands, input_bets)
can_check = a['last_bettor'] == -1
min_bet_amt = a['street_last_bet_to'] - client_last_bet

print(f'min_bet_amt: {min_bet_amt}')
mask_illegals(logits, pot, min_bet_amt)
# fold is illegal
#print(logits)
#input()
regrets = regret_match(logits)
#print(regrets)
#input()
if r['action'] == '':
regrets[0] = -1
actidx = np.argmax(regrets).item()
Expand Down Expand Up @@ -497,13 +514,29 @@ def Login(username, password):
sys.exit(-1)
return token

def auto(log_path, start_iter, num_hands=10):

def auto(log_path, start_iter, num_hands=10, use_wandb=True, config={}, resume="", mlx=True):
model_config = ModelConfig()
model_config.update_from_dict(config)
import time
eval_cfr_iter = start_iter
if use_wandb:
print("initializing wandb")
kwargs = {}
if resume:
kwargs["id"] = resume
kwargs["resume"] = "must"
wandb.init(
project = "poker ai",
config={**config},
name=run_name,
**kwargs
)

while 1:
both_player_exists = True
for i in range(NUM_PLAYERS):
path = os.path.join(MODELS_PATH, str(eval_cfr_iter), str(i), 'model.pt')
path = os.path.join(MODELS_PATH, str(eval_cfr_iter), str(i), 'model.safetensors')
if not os.path.exists(path):
both_player_exists = False

Expand All @@ -512,8 +545,10 @@ def auto(log_path, start_iter, num_hands=10):

for i in range(NUM_PLAYERS):
assert eval_cfr_iter == -1 or eval_cfr_iter < len(os.listdir(MODELS_PATH))
path = os.path.join(MODELS_PATH, str(eval_cfr_iter), str(i), 'model.pt')
player_models[i] = path
path = os.path.join(MODELS_PATH, str(eval_cfr_iter), str(i), 'model.safetensors')
gpt = PokerGPT(config=model_config)
gpt.load_weights(path)
player_models[i] = gpt

if username and password:
token = Login(username, password)
Expand All @@ -538,6 +573,10 @@ def auto(log_path, start_iter, num_hands=10):
'session_baseline_total_avg_mbb': sum(baseline_totals)/len(baseline_totals)/bb*1000 if baseline_totals else 0
}


if use_wandb:
wandb.log(eval_results)

with open(f'{log_path}/eval.log', 'a') as f:
for key, value in eval_results.items():
log_line = f'{key} = {value}\n'
Expand All @@ -550,6 +589,8 @@ def auto(log_path, start_iter, num_hands=10):
else:
time.sleep(10)
print('sleeping 10 seconds...')
if use_wandb:
wandb.finish()


if __name__ == '__main__':
Expand All @@ -559,12 +600,17 @@ def auto(log_path, start_iter, num_hands=10):
parser = argparse.ArgumentParser(description='Slumbot API example')
parser.add_argument('--username', type=str, default="")
parser.add_argument('--password', type=str, default="")
parser.add_argument('--log_path', type=str)
parser.add_argument('--log_path', type=str, default=None)
parser.add_argument('--num_hands', type=int, default=1000)
parser.add_argument('--plot_all',type=int, default=1)
parser.add_argument('--start_iter', type=int, default=1)
parser.add_argument('--plot_intervals',type=int, default=1)
parser.add_argument('--auto', type=int, default=1)
parser.add_argument('--write', type=int, default=1)
parser.add_argument('--wandb', type=int, default=1)
parser.add_argument('--name', type=str, default="")
parser.add_argument('--mlx', type=int, default=1)
parser.add_argument('--resume', type=str, default="")
args = parser.parse_args()
username = args.username
password = args.password
Expand All @@ -573,14 +619,24 @@ def auto(log_path, start_iter, num_hands=10):
plot_all = args.plot_all
start_iter = args.start_iter
plot_intervals = args.plot_intervals
automatic = args.auto
write = args.write
use_wandb = args.wandb
run_name = args.name
mlx = args.mlx
resume = args.resume

if not log_path:
log_path = os.path.join('./out',sorted(os.listdir('out'))[-1])

FILE_PATH = f'{log_path}/const.log'
MODELS_PATH = log_path
NUM_PLAYERS, MODEL_DIM, NUM_ACTIONS, MAX_ROUND_BETS = read_config(FILE_PATH)
config = read_config(FILE_PATH)
NUM_PLAYERS, MODEL_DIM, NUM_HEADS, N_LAYERS, NUM_ACTIONS, MAX_ROUND_BETS, _, _, _, _ = tuple(config.values())

auto(log_path, start_iter, num_hands=num_hands)
exit(-1)
if automatic:
auto(log_path, start_iter, num_hands=num_hands, use_wandb=use_wandb, config=config, mlx=mlx, resume=resume)
exit(-1)

only_dirs = [dir for dir in os.listdir(MODELS_PATH) if os.path.splitext(dir)[1] == '']
total_iters = len(only_dirs) if plot_all else 1
Expand Down
34 changes: 23 additions & 11 deletions lib/poker_inference_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/script.h>
#include "model/model.h"
#include <array>

Expand Down Expand Up @@ -123,17 +124,28 @@ std::vector<float> forward(
}

DeepCFRModel model;
torch::load(model, model_path);
auto logits = model->forward(
b_hands,
b_flops,
b_turns,
b_rivers,
b_fracs,
b_status
);

// Convert the strategy tensor to a std::vector<float>
torch::Tensor logits;

try {
// Try loading and inferencing on CPU first
torch::load(model, model_path);
logits = model->forward(b_hands, b_flops, b_turns, b_rivers, b_fracs, b_status);
} catch (const c10::Error& e) {
// If CPU loading fails, try GPU
torch::load(model, model_path, torch::kCUDA);
// Move input tensors to GPU
b_hands = b_hands.to(torch::kCUDA);
b_flops = b_flops.to(torch::kCUDA);
b_turns = b_turns.to(torch::kCUDA);
b_rivers = b_rivers.to(torch::kCUDA);
b_fracs = b_fracs.to(torch::kCUDA);
b_status = b_status.to(torch::kCUDA);

logits = model->forward(b_hands, b_flops, b_turns, b_rivers, b_fracs, b_status);
logits = logits.to(torch::kCPU); // Move result back to CPU
}

// Convert the logits tensor to a std::vector<float>
std::vector<float> logits_values(logits.size(1));
auto logits_accessor = logits.accessor<float, 2>();
for (int i = 0; i < logits.size(1); ++i) {
Expand Down
Loading