(* Global settings, held in refs so that the command-line handlers in the
   entry point can overwrite the defaults below. *)

(* Time horizon (option -horizon): length of the rollouts used by
   get_value and get_sample. *)
let horizon = ref 4
(* Discount factor (option -gamma); the value 1. makes gamma_draw pick the
   sampling time uniformly. *)
let gamma = ref 1.
(* Number of improvement iterations run by the entry point (option
   -num_iter). *)
let iterations = ref 4
(* Samples per learning problem (option -num_samples); also the number of
   rollouts averaged by get_value. *)
let num_samples = ref 4
(* Number of diagnostic traces printed for each accepted policy (option
   -trace_num); 0 prints none. *)
let trace_num = ref 0
(* Length, in steps, of each diagnostic trace (option -trace_len). *)
let trace_len = ref 100

(* Action strings indexed by the integer a basic policy returns.  The
   entry point replaces this default through parse_actions before any
   action is taken.
   NOTE(review): the default entries end with a newline whereas the
   strings built by extract do not - confirm the default is never used. *)
let action_set = ref [|"0\n";"1\n"|] 

(* [extract ranges] enumerates the cartesian product of the inclusive
   integer ranges [(lo, hi); ...].  Each combination is rendered as its
   components separated by single spaces, and the result is in
   lexicographic order: the first range varies slowest, the last fastest.
   An empty list of ranges, or any range with lo > hi, yields [].

   The previous implementation accumulated with [temp @ !ret] and then
   reversed the whole list, which also reversed the order inside every
   prefix group (for two ranges the last component came out descending). *)
let rec extract actions =
  (* [range lo hi] is the ascending list lo, lo+1, ..., hi. *)
  let range lo hi =
    if hi < lo then []
    else
      let rec build i acc =
        if i = lo then i :: acc else build (i - 1) (i :: acc)
      in
      build hi []
  in
  match actions with
  | [] -> []
  | [ (lo, hi) ] -> List.map string_of_int (range lo hi)
  | (lo, hi) :: rest ->
      (* Suffixes are computed once and shared by every leading value. *)
      let suffixes = extract rest in
      let with_prefix i =
        let prefix = string_of_int i ^ " " in
        List.map (fun suffix -> prefix ^ suffix) suffixes
      in
      List.concat (List.map with_prefix (range lo hi))

(* Parse the action description line sent by the generative model and
   store the resulting action strings in [action_set].

   The line has the form  actions: {lo,hi}x{lo,hi}x...  where every
   brace group is an inclusive integer range; the ranges are expanded
   by [extract].

   Raises Failure when a group has no comma or a bound is not an
   integer. *)
let parse_actions action_line =
  (* Delimiters are the leading prefix, the separator between two groups
     and the closing brace.  Alternation in Str is backslash-bar, so the
     backslash has to be escaped inside the OCaml string literal. *)
  let zone = Str.regexp "actions: {\\|}x{\\|}" in
  let parse_range s =
    let comma_loc =
      try String.index s ','
      with Not_found ->
        failwith ("rlgen: malformed action range \"" ^ s ^ "\"")
    in
    (* Trim so that stray blanks around a bound do not break int_of_string. *)
    let lo = String.trim (String.sub s 0 comma_loc)
    and hi =
      String.trim
        (String.sub s (comma_loc + 1) (String.length s - comma_loc - 1))
    in
    (int_of_string lo, int_of_string hi)
  in
  let ranges = List.map parse_range (Str.split zone action_line) in
  action_set := Array.of_list (extract ranges)
      
(* [deref policy policy_choice] selects a basic policy from a mixture.
   [policy] is an array of (cumulative probability, basic policy) pairs
   with increasing first components; the result is the second component
   of the entry found by bisection as the first one whose cumulative
   probability is strictly greater than [policy_choice].  When the
   bisection ends on two candidates and the lower one does not qualify,
   the upper one is returned unconditionally. *)
let deref policy policy_choice =
  let exceeds idx = fst policy.(idx) > policy_choice in
  (* The answer always lies in the closed index interval [lo, hi]. *)
  let rec search lo hi =
    if hi - lo > 1 then begin
      let mid = (lo + hi) / 2 in
      if exceeds mid then search lo mid else search mid hi
    end else
      let chosen = if exceeds lo then lo else hi in
      snd policy.(chosen)
  in
  search 0 (Array.length policy - 1)

(* A policy is an array of (cumulative probability, basic policy) pairs.
   [action policy observation] draws a uniform number in the unit
   interval, picks the matching basic policy with [deref], runs it on
   [observation], reads its answer as an integer and returns the entry
   of [action_set] at that index. *)
let action policy observation =
  let basic_policy = deref policy (Random.float 1.) in
  let index = int_of_string (basic_policy observation) in
  !action_set.(index)

(* [trace generative_model policy state observation t output] follows
   [policy] for [t] steps starting from [state] / [observation].  Every
   step chooses an action with [action], queries [generative_model] with
   the current state and that action, and adds the returned reward to a
   running total.  When [output] is true each transition is also written
   to stderr.  Returns (last observation, summed reward, last state); if
   [t] is not positive the starting values and a reward of 0. come back
   unchanged. *)
let trace generative_model policy state observation t output =
  let rec step remaining cur_state cur_obs total =
    if remaining <= 0 then (cur_obs, total, cur_state)
    else begin
      let chosen = action policy cur_obs in
      let (next_obs, reward, next_state) = generative_model cur_state chosen in
      if output then
        output_string stderr
          ("rlgen: trace: state = " ^ cur_state ^ " action = " ^ chosen
           ^ " next state = " ^ next_state ^ "\n");
      step (remaining - 1) next_state next_obs (total +. reward)
    end
  in
  step t state observation 0.

(* [gamma_draw gamma horizon] draws a random time step.

   - When [gamma] is exactly 1. the step is uniform over 0 .. horizon-1
     (Random.int raises Invalid_argument if [horizon] is not positive).
   - Otherwise the step t is drawn with probability proportional to
     gamma^t: a uniform number is drawn below 1 / (1 - gamma), the sum
     of the whole geometric series, and t is the first index whose
     partial sum gamma^0 + ... + gamma^t exceeds it.  [horizon] is not
     consulted in this case.

   The previous version added gamma^t when moving to index t+1, so the
   partial sums ran 1, 2, 2 + gamma, ... and gamma^0 was counted twice;
   with gamma = 0.5 no step beyond 1 could ever be returned. *)
let gamma_draw gamma horizon =
  if gamma = 1.
  then Random.int horizon
  else
    let draw = Random.float (1. /. (1. -. gamma)) in
    (* Invariant: [sum] is gamma^0 + ... + gamma^t. *)
    let rec find_t t sum =
      if sum > draw then t
      else
        let next = sum +. gamma ** float (t + 1) in
        (* Partial sums only approach the series limit; once a further
           term no longer increases the floating-point sum it can never
           pass [draw], so stop rather than loop forever. *)
        if next <= sum then t else find_t (t + 1) next
    in
    find_t 0 1.
    
(* [get_sample generative_model (value, policy)] builds one
   cost-sensitive training example.  It draws a time step with
   [gamma_draw], resets the model by querying it with two empty strings,
   follows [policy] for that many steps, and then, for every action in
   [action_set], takes that action once and continues with [policy] for
   the remaining horizon - (t + 1) steps.  Returns the observation
   reached after the initial rollout together with the array of total
   rewards, one entry per action.  The value component of the pair is
   not used. *)
let get_sample generative_model (_, policy) =
  let t = gamma_draw !gamma !horizon in
  let start_obs, _, start_state = generative_model "" "" in
  let observation, _, state =
    trace generative_model policy start_state start_obs t false in
  let remaining = !horizon - (t + 1) in
  (* Reward of forcing [act] now and following [policy] afterwards. *)
  let value_of_action act =
    let next_obs, immediate, next_state = generative_model state act in
    let _, rollout, _ =
      trace generative_model policy next_state next_obs remaining false in
    immediate +. rollout
  in
  observation, Array.map value_of_action !action_set
    
(* [make_samples generative_model policy] collects [!num_samples]
   independent training examples by calling [get_sample] once per array
   slot. *)
let make_samples generative_model policy =
  let draw_one _ = get_sample generative_model policy in
  Array.init !num_samples draw_one
      
(* [new_policy (ic, oc) policy alpha] mixes a classifier process into an
   existing mixture policy.  The result has one more entry than [policy]:
   slot 0 carries cumulative probability [alpha] and a basic policy that
   writes the observation followed by two newlines to [oc], flushes, and
   returns the next line read from [ic]; each old entry follows with its
   cumulative probability c mapped to alpha + (1 - alpha) * c. *)
let new_policy (ic,oc) (policy: (float * (string -> string)) array) alpha =
  let ask_classifier observation =
    output_string oc observation;
    output_string oc "\n\n";
    flush oc;
    input_line ic
  in
  let rescale (cumulative, basic_policy) =
    (alpha +. (1. -. alpha) *. cumulative, basic_policy)
  in
  Array.append [| (alpha, ask_classifier) |] (Array.map rescale policy)
    
(* [get_value generative_model classifier policy alpha] estimates the
   value of [policy] mixed with [classifier] at weight [alpha]: it runs
   [!num_samples] rollouts of [!horizon] steps, each from a freshly
   reset model, and returns the mean summed reward.
   Perhaps we should really account for pessimism here. *)
let get_value generative_model classifier policy alpha =
  let mixed = new_policy classifier policy alpha in
  let episodes = !num_samples in
  let rec accumulate finished total =
    if finished >= episodes then total
    else
      let start_obs, _, start_state = generative_model "" "" in
      let _, reward, _ =
        trace generative_model mixed start_state start_obs !horizon false in
      accumulate (finished + 1) (total +. reward)
  in
  let total = accumulate 0 0. in
  total /. float !num_samples
    
(* [linear_interpolation generative_model iteration classifier
   (policy_value, policy) alpha] looks for a mixing weight at which
   adding [classifier] to [policy] improves the estimated value.
   Starting from [alpha], the weight is halved after every failed
   attempt; once it falls below 1 / horizon the incumbent pair is
   returned unchanged.  On the first strict improvement the new value is
   reported on stderr, [!trace_num] diagnostic traces of [!trace_len]
   steps are printed, and the improved (value, policy) pair is returned.
   [iteration] is only passed along to the recursive call. *)
let rec linear_interpolation generative_model iteration classifier (policy_value,policy) alpha =
  let incumbent = (policy_value, policy) in
  let too_small = alpha < 1. /. float !horizon in
  if too_small then incumbent
  else
    let candidate_value = get_value generative_model classifier policy alpha in
    let improved = candidate_value > policy_value in
    if not improved then
      linear_interpolation generative_model iteration classifier
        incumbent (alpha /. 2.)
    else begin
      output_string stderr
        ("new policy_value = " ^ string_of_float candidate_value
         ^ " alpha = " ^ string_of_float alpha ^ "\n");
      flush stderr;
      (* One verbose rollout from a freshly reset model. *)
      let show_trace () =
        let start_obs, _, start_state = generative_model "" "" in
        ignore
          (trace generative_model (new_policy classifier policy alpha)
             start_state start_obs !trace_len true)
      in
      for _i = 1 to !trace_num do
        show_trace ()
      done;
      (candidate_value, new_policy classifier policy alpha)
    end

(* Program entry point.  Parses the command line, starts the generative
   model process, evaluates the uniformly random policy, then runs
   [!iterations] rounds of: collect samples, train a fresh classifier
   process on them, and try to mix that classifier into the current
   policy with linear_interpolation. *)
let _ =
  let arguments = ref [] in
  (* NOTE(review): Recursive_arg is a project module not shown here;
     presumably it mirrors Arg.parse - confirm.  Positional arguments
     are collected in reverse order by the anonymous handler. *)
  Recursive_arg.parse 
    ["-horizon", Arg.Int (fun x -> horizon := x), "time horizon";
     "-num_samples", Arg.Int (fun x -> num_samples := x), "samples per learning problem";
     "-num_iter", Arg.Int (fun x -> iterations := x), "number of iterations";
     "-gamma", Arg.Float (fun x -> gamma := x), "discount factor";
     "-trace_num", Arg.Int (fun x -> trace_num := x), "output num testing trace from each policy";
     "-trace_len", Arg.Int (fun x -> trace_len := x), "make each testing trace len long";
   ]
    (fun x -> arguments := x::!arguments)
    "usage: rlgen cost_sensitive_classification_algorithm generative_model";
  
  (* Exactly two positional arguments are required: the classifier
     command, then the generative model command. *)
  let arguments = Array.of_list (List.rev !arguments) in
  if Array.length arguments <> 2
  then begin
    output_string stderr "rlgen: wrong number of arguments\n";
    exit 1;
  end;
  
  (* Start a new classifier process and feed it the training samples:
     for each sample one line with the observation, then one line of
     space-separated costs, one per action.  A cost is max_reward minus
     the reward of the action, where max_reward is the largest reward of
     the sample or 0. if every reward is negative.  The channels are
     returned so the process can later be queried as a basic policy.
     NOTE(review): no close_process call is visible in this file, so
     these processes appear to stay open for the whole run - confirm. *)
  let cs_classifier = (fun samples -> 
    let (ic,oc) = Unix.open_process arguments.(0) in
    Unix.set_close_on_exec (Unix.descr_of_out_channel oc);
    Unix.set_close_on_exec (Unix.descr_of_in_channel ic);
    Array.iter (fun (obs, rewards) -> 
      output_string oc (obs^"\n");
      let max_reward = Array.fold_left (fun x y -> max x y) 0. rewards in
      Array.iter (fun f -> output_string oc (string_of_float (max_reward -. f)^" ") ) rewards;
      output_string oc "\n";
	       ) samples;
    flush oc;
    (ic,oc)) in
  (* The generative model is started once; the first line it prints
     describes the action set. *)
  let (ig,og) = Unix.open_process arguments.(1) in
  let actions = input_line ig in
  parse_actions actions;
  (* One model query: send the state line and the action line, then read
     three lines back - observation, reward, next state.  The code above
     resets the model by calling this with two empty strings. *)
  let generative_model = (fun state action ->
    output_string og (state^"\n");
    output_string og (action^"\n");
    flush og;
    let observation = input_line ig in
    let reward = input_line ig in
    let state = input_line ig in
    (observation, float_of_string reward, state)) in
  (* Initial policy: a single basic policy that picks an action index
     uniformly at random.  It is evaluated with alpha = 0., so the
     (stdin,stdout) pair handed to new_policy gets cumulative
     probability 0. and is never selected by deref. *)
  let policy = [|1., fun state -> string_of_int (Random.int (Array.length !action_set))|] in
  let policy_value = get_value generative_model (stdin,stdout) policy 0. in
  let policy = ref (policy_value, policy) in

  output_string stderr ("uniform policy value = "^string_of_float policy_value^"\n");
  flush stderr;

  (* Each round trains a new classifier on samples drawn under the
     current policy; the interpolation search starts at alpha = 1. *)
  for i = 1 to !iterations do
    output_string stderr ("iteration "^string_of_int i^"\n");
    flush stderr;
    let samples = make_samples generative_model !policy in
    let classifier = cs_classifier samples in
    policy := linear_interpolation generative_model i classifier !policy 1.;
   done;
  exit 0
  
