| 243 | } |
| 244 | |
| 245 | func TestTraceBatchClose(t *testing.T) { |
| 246 | t.Parallel() |
| 247 | |
| 248 | tracer := &testTracer{} |
| 249 | |
| 250 | ctr := defaultConnTestRunner |
| 251 | ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { |
| 252 | config := defaultConnTestRunner.CreateConfig(ctx, t) |
| 253 | config.Tracer = tracer |
| 254 | return config |
| 255 | } |
| 256 | |
| 257 | ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) |
| 258 | defer cancel() |
| 259 | |
| 260 | pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { |
| 261 | traceBatchStartCalled := false |
| 262 | tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { |
| 263 | traceBatchStartCalled = true |
| 264 | require.NotNil(t, data.Batch) |
| 265 | require.Equal(t, 2, data.Batch.Len()) |
| 266 | return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") |
| 267 | } |
| 268 | |
| 269 | traceBatchQueryCalledCount := 0 |
| 270 | tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { |
| 271 | traceBatchQueryCalledCount++ |
| 272 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 273 | require.NoError(t, data.Err) |
| 274 | } |
| 275 | |
| 276 | traceBatchEndCalled := false |
| 277 | tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { |
| 278 | traceBatchEndCalled = true |
| 279 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 280 | require.NoError(t, data.Err) |
| 281 | } |
| 282 | |
| 283 | batch := &pgx.Batch{} |
| 284 | batch.Queue(`select 1`) |
| 285 | batch.Queue(`select 2`) |
| 286 | |
| 287 | br := conn.SendBatch(context.Background(), batch) |
| 288 | require.True(t, traceBatchStartCalled) |
| 289 | err := br.Close() |
| 290 | require.NoError(t, err) |
| 291 | require.EqualValues(t, 2, traceBatchQueryCalledCount) |
| 292 | require.True(t, traceBatchEndCalled) |
| 293 | }) |
| 294 | } |
| 295 | |
| 296 | func TestTraceBatchErrorWhileReadingResults(t *testing.T) { |
| 297 | t.Parallel() |