Taking a sample from a Categorical distribution in PyTorch

May 2024 · 2 minute read

I'm currently working on a deep reinforcement learning problem, and I'm using a categorical distribution to let the agent pick a random action. This is the code:

    import torch as T
    from torch.distributions import Categorical

    def choose_action(self, enc_current_node, goal_node):
        # print('nn')
        # vector = self.convert_vector(observation, end)
        state = T.tensor([[enc_current_node, goal_node]], dtype=T.float)
        pi, v = self.forward(state)
        probs = T.softmax(pi, dim=1)
        print(probs)
        dist = Categorical(probs)
        # take a sample from the categorical dist; the result is an index in 0-21
        action = dist.sample().numpy()[0]
        return action

The output of Categorical(probs) looks like this:

    probs = T.tensor([[1.5857e-03, 8.9753e-01, 2.8500e-03, 9.0585e-03, 3.6661e-04,
                       6.8342e-08, 7.2956e-04, 3.3966e-05, 3.7150e-04, 1.8317e-05,
                       4.1543e-04, 4.7550e-05, 5.2323e-05, 1.1337e-03, 1.6356e-05,
                       6.9848e-03, 2.2993e-03, 1.0874e-06, 2.0343e-04, 2.3616e-03,
                       1.3477e-02, 6.1464e-02]])
    c = Categorical(probs)
    c

    >>> output:
    >>> Categorical(probs: torch.Size([1, 22]))

Now, in the function, I used dist.sample to draw one of the 22 elements, but I noticed that PyTorch's sample method returns the same number about 90% of the time, as you can see here:

    samples = []  # renamed from `list` to avoid shadowing the built-in
    for i in range(100):
        samples.append(c.sample()[0].item())
    print(samples)

    >>> output:
    >>> [1, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 21, 1, 3, 1, 1, 1, 1, 1, 1, 21, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 21, 1, 3, 1, 1, 1, 1, 1, 18, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 21, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3]

As you can see above, the sample method outputs 1 most of the time. My question is: is there a way to draw a random sample from the Categorical distribution rather than this?


1 Answer

If you look at the sampling probabilities in probs, you will see that class 1 (the second entry) has by far the largest probability, and almost all the others are below 1%. If you are not familiar with scientific notation, here they are formatted as rounded percentages:

    for label, p in enumerate(probs[0]):
        print(f'{label:2}: {100*p:5.2f}%')

     0:  0.16%
     1: 89.75% <---
     2:  0.28%
     3:  0.91%
     4:  0.04%
     5:  0.00%
     6:  0.07%
     7:  0.00%
     8:  0.04%
     9:  0.00%
    10:  0.04%
    11:  0.00%
    12:  0.01%
    13:  0.11%
    14:  0.00%
    15:  0.70%
    16:  0.23%
    17:  0.00%
    18:  0.02%
    19:  0.24%
    20:  1.35%
    21:  6.15%

Hence about 90% of the samples drawn from this distribution will be 1. The sampling is random; it simply follows these probabilities.
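One way to check this is to draw many samples and compare the observed frequencies with probs. The following is a minimal sketch rather than part of the original answer: the sample count of 100,000, the fixed seed, and the names n_samples, counts and freqs are assumptions chosen for illustration.

```python
import torch
from torch.distributions import Categorical

# Probabilities copied from the question (shape: 1 x 22).
probs = torch.tensor([[1.5857e-03, 8.9753e-01, 2.8500e-03, 9.0585e-03, 3.6661e-04,
                       6.8342e-08, 7.2956e-04, 3.3966e-05, 3.7150e-04, 1.8317e-05,
                       4.1543e-04, 4.7550e-05, 5.2323e-05, 1.1337e-03, 1.6356e-05,
                       6.9848e-03, 2.2993e-03, 1.0874e-06, 2.0343e-04, 2.3616e-03,
                       1.3477e-02, 6.1464e-02]])
dist = Categorical(probs)

torch.manual_seed(0)   # assumption: fixed seed so the run is repeatable
n_samples = 100_000    # assumption: large enough for stable frequencies

# sample((n,)) draws n samples per batch element, giving shape (n_samples, 1).
samples = dist.sample((n_samples,))

# Count how often each of the 22 classes was drawn.
counts = torch.bincount(samples.flatten(), minlength=probs.shape[1])
freqs = counts.float() / n_samples

for label, (p, f) in enumerate(zip(probs[0].tolist(), freqs.tolist())):
    print(f'{label:2}: expected {100*p:5.2f}%  observed {100*f:5.2f}%')
```

If the observed column tracks the expected one, the sampler is doing its job, and the many 1s come from the policy's output probabilities, not from sample itself.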

