| 363 | } |
| 364 | |
| 365 | func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) { |
| 366 | t.Parallel() |
| 367 | |
| 368 | tracer := &testTracer{} |
| 369 | |
| 370 | ctr := defaultConnTestRunner |
| 371 | ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { |
| 372 | config := defaultConnTestRunner.CreateConfig(ctx, t) |
| 373 | config.Tracer = tracer |
| 374 | return config |
| 375 | } |
| 376 | |
| 377 | ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) |
| 378 | defer cancel() |
| 379 | |
| 380 | pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { |
| 381 | traceBatchStartCalled := false |
| 382 | tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { |
| 383 | traceBatchStartCalled = true |
| 384 | require.NotNil(t, data.Batch) |
| 385 | require.Equal(t, 3, data.Batch.Len()) |
| 386 | return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") |
| 387 | } |
| 388 | |
| 389 | traceBatchQueryCalledCount := 0 |
| 390 | tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { |
| 391 | traceBatchQueryCalledCount++ |
| 392 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 393 | if traceBatchQueryCalledCount == 2 { |
| 394 | require.Error(t, data.Err) |
| 395 | } else { |
| 396 | require.NoError(t, data.Err) |
| 397 | } |
| 398 | } |
| 399 | |
| 400 | traceBatchEndCalled := false |
| 401 | tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { |
| 402 | traceBatchEndCalled = true |
| 403 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 404 | require.Error(t, data.Err) |
| 405 | } |
| 406 | |
| 407 | batch := &pgx.Batch{} |
| 408 | batch.Queue(`select 1`) |
| 409 | batch.Queue(`select 2/n-2 from generate_series(0,10) n`) |
| 410 | batch.Queue(`select 3`) |
| 411 | |
| 412 | br := conn.SendBatch(context.Background(), batch) |
| 413 | require.True(t, traceBatchStartCalled) |
| 414 | err := br.Close() |
| 415 | require.Error(t, err) |
| 416 | require.EqualValues(t, 2, traceBatchQueryCalledCount) |
| 417 | require.True(t, traceBatchEndCalled) |
| 418 | }) |
| 419 | } |
| 420 | |
| 421 | func TestTraceCopyFrom(t *testing.T) { |
| 422 | t.Parallel() |